Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add plot method to (most) DataModules #814

Merged
merged 17 commits into from
Oct 18, 2022
8 changes: 8 additions & 0 deletions tests/datamodules/test_fair1m.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@

import os

import matplotlib.pyplot as plt
import pytest

from torchgeo.datamodules import FAIR1MDataModule
from torchgeo.datasets import unbind_samples


class TestFAIR1MDataModule:
Expand All @@ -32,3 +34,9 @@ def test_val_dataloader(self, datamodule: FAIR1MDataModule) -> None:

def test_test_dataloader(self, datamodule: FAIR1MDataModule) -> None:
next(iter(datamodule.test_dataloader()))

def test_plot(self, datamodule: FAIR1MDataModule) -> None:
batch = next(iter(datamodule.train_dataloader()))
sample = unbind_samples(batch)[0]
datamodule.plot(sample)
plt.close()
8 changes: 8 additions & 0 deletions tests/datamodules/test_loveda.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@

import os

import matplotlib.pyplot as plt
import pytest

from torchgeo.datamodules import LoveDADataModule
from torchgeo.datasets import unbind_samples


class TestLoveDADataModule:
Expand All @@ -32,3 +34,9 @@ def test_val_dataloader(self, datamodule: LoveDADataModule) -> None:

def test_test_dataloader(self, datamodule: LoveDADataModule) -> None:
next(iter(datamodule.test_dataloader()))

def test_plot(self, datamodule: LoveDADataModule) -> None:
batch = next(iter(datamodule.train_dataloader()))
sample = unbind_samples(batch)[0]
datamodule.plot(sample)
plt.close()
8 changes: 8 additions & 0 deletions tests/datamodules/test_potsdam.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@

import os

import matplotlib.pyplot as plt
import pytest
from _pytest.fixtures import SubRequest

from torchgeo.datamodules import Potsdam2DDataModule
from torchgeo.datasets import unbind_samples


class TestPotsdam2DDataModule:
Expand Down Expand Up @@ -34,3 +36,9 @@ def test_val_dataloader(self, datamodule: Potsdam2DDataModule) -> None:

def test_test_dataloader(self, datamodule: Potsdam2DDataModule) -> None:
next(iter(datamodule.test_dataloader()))

def test_plot(self, datamodule: Potsdam2DDataModule) -> None:
batch = next(iter(datamodule.train_dataloader()))
sample = unbind_samples(batch)[0]
datamodule.plot(sample)
plt.close()
8 changes: 8 additions & 0 deletions tests/datamodules/test_usavars.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@

import os

import matplotlib.pyplot as plt
import pytest
from _pytest.fixtures import SubRequest

from torchgeo.datamodules import USAVarsDataModule
from torchgeo.datasets import unbind_samples


class TestUSAVarsDataModule:
Expand Down Expand Up @@ -38,3 +40,9 @@ def test_test_dataloader(self, datamodule: USAVarsDataModule) -> None:
assert len(datamodule.test_dataloader()) == 1
sample = next(iter(datamodule.test_dataloader()))
assert sample["image"].shape[0] == datamodule.batch_size

def test_plot(self, datamodule: USAVarsDataModule) -> None:
batch = next(iter(datamodule.train_dataloader()))
sample = unbind_samples(batch)[0]
datamodule.plot(sample)
plt.close()
8 changes: 8 additions & 0 deletions tests/datamodules/test_vaihingen.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@

import os

import matplotlib.pyplot as plt
import pytest
from _pytest.fixtures import SubRequest

from torchgeo.datamodules import Vaihingen2DDataModule
from torchgeo.datasets import unbind_samples


class TestVaihingen2DDataModule:
Expand Down Expand Up @@ -34,3 +36,9 @@ def test_val_dataloader(self, datamodule: Vaihingen2DDataModule) -> None:

def test_test_dataloader(self, datamodule: Vaihingen2DDataModule) -> None:
next(iter(datamodule.test_dataloader()))

def test_plot(self, datamodule: Vaihingen2DDataModule) -> None:
batch = next(iter(datamodule.train_dataloader()))
sample = unbind_samples(batch)[0]
datamodule.plot(sample)
plt.close()
8 changes: 8 additions & 0 deletions tests/datamodules/test_xview2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@

import os

import matplotlib.pyplot as plt
import pytest
from _pytest.fixtures import SubRequest

from torchgeo.datamodules import XView2DataModule
from torchgeo.datasets import unbind_samples


class TestXView2DataModule:
Expand Down Expand Up @@ -34,3 +36,9 @@ def test_val_dataloader(self, datamodule: XView2DataModule) -> None:

def test_test_dataloader(self, datamodule: XView2DataModule) -> None:
next(iter(datamodule.test_dataloader()))

def test_plot(self, datamodule: XView2DataModule) -> None:
batch = next(iter(datamodule.train_dataloader()))
sample = unbind_samples(batch)[0]
datamodule.plot(sample)
plt.close()
32 changes: 27 additions & 5 deletions tests/trainers/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,9 @@ def test_no_logger(self) -> None:
def model_kwargs(self) -> Dict[Any, Any]:
return {
"classification_model": "resnet18",
"in_channels": 1,
"in_channels": 13,
"loss": "ce",
"num_classes": 2,
"num_classes": 10,
"weights": "random",
}

Expand Down Expand Up @@ -124,6 +124,17 @@ def test_invalid_weights(self, model_kwargs: Dict[Any, Any]) -> None:
with pytest.raises(ValueError, match=match):
ClassificationTask(**model_kwargs)

def test_missing_attributes(
self, model_kwargs: Dict[Any, Any], monkeypatch: MonkeyPatch
) -> None:
monkeypatch.delattr(EuroSATDataModule, "plot")
datamodule = EuroSATDataModule(
root="tests/data/eurosat", batch_size=1, num_workers=0
)
model = ClassificationTask(**model_kwargs)
trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1)
trainer.validate(model=model, datamodule=datamodule)


class TestMultiLabelClassificationTask:
@pytest.mark.parametrize(
Expand Down Expand Up @@ -179,9 +190,9 @@ def test_no_logger(self) -> None:
def model_kwargs(self) -> Dict[Any, Any]:
return {
"classification_model": "resnet18",
"in_channels": 1,
"loss": "ce",
"num_classes": 1,
"in_channels": 14,
"loss": "bce",
"num_classes": 19,
"weights": "random",
}

Expand All @@ -190,3 +201,14 @@ def test_invalid_loss(self, model_kwargs: Dict[Any, Any]) -> None:
match = "Loss type 'invalid_loss' is not valid."
with pytest.raises(ValueError, match=match):
MultiLabelClassificationTask(**model_kwargs)

def test_missing_attributes(
self, model_kwargs: Dict[Any, Any], monkeypatch: MonkeyPatch
) -> None:
monkeypatch.delattr(BigEarthNetDataModule, "plot")
datamodule = BigEarthNetDataModule(
root="tests/data/bigearthnet", batch_size=1, num_workers=0
)
model = MultiLabelClassificationTask(**model_kwargs)
trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1)
trainer.validate(model=model, datamodule=datamodule)
16 changes: 16 additions & 0 deletions tests/trainers/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any, Dict, Type, cast

import pytest
from _pytest.monkeypatch import MonkeyPatch
from omegaconf import OmegaConf
from pytorch_lightning import LightningDataModule, Trainer

Expand Down Expand Up @@ -63,3 +64,18 @@ def test_invalid_model(self) -> None:
match = "module 'torchvision.models' has no attribute 'invalid_model'"
with pytest.raises(AttributeError, match=match):
RegressionTask(model="invalid_model", pretrained=False)

@pytest.fixture
def model_kwargs(self) -> Dict[Any, Any]:
return {"model": "resnet18", "pretrained": False}

def test_missing_attributes(
self, model_kwargs: Dict[Any, Any], monkeypatch: MonkeyPatch
) -> None:
monkeypatch.delattr(COWCCountingDataModule, "plot")
datamodule = COWCCountingDataModule(
root="tests/data/cowc_counting", batch_size=1, num_workers=0
)
model = RegressionTask(**model_kwargs)
trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1)
trainer.validate(model=model, datamodule=datamodule)
15 changes: 13 additions & 2 deletions tests/trainers/test_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ def model_kwargs(self) -> Dict[Any, Any]:
"segmentation_model": "unet",
"encoder_name": "resnet18",
"encoder_weights": None,
"in_channels": 1,
"num_classes": 2,
"in_channels": 3,
"num_classes": 6,
"loss": "ce",
"ignore_index": 0,
}
Expand Down Expand Up @@ -134,3 +134,14 @@ def test_ignoreindex_with_jaccard(self, model_kwargs: Dict[Any, Any]) -> None:
match = "ignore_index has no effect on training when loss='jaccard'"
with pytest.warns(UserWarning, match=match):
SemanticSegmentationTask(**model_kwargs)

def test_missing_attributes(
self, model_kwargs: Dict[Any, Any], monkeypatch: MonkeyPatch
) -> None:
monkeypatch.delattr(LandCoverAIDataModule, "plot")
datamodule = LandCoverAIDataModule(
root="tests/data/landcoverai", batch_size=1, num_workers=0
)
model = SemanticSegmentationTask(**model_kwargs)
trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1)
trainer.validate(model=model, datamodule=datamodule)
2 changes: 1 addition & 1 deletion torchgeo/datamodules/cowc.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class COWCCountingDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the COWC Counting dataset."""

def __init__(
self, seed: int, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
self, seed: int = 0, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
) -> None:
"""Initialize a LightningDataModule for COWC Counting based DataLoaders.

Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datamodules/cyclone.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class CycloneDataModule(pl.LightningDataModule):
"""

def __init__(
self, seed: int, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
self, seed: int = 0, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
) -> None:
"""Initialize a LightningDataModule for NASA Cyclone based DataLoaders.

Expand Down
8 changes: 8 additions & 0 deletions torchgeo/datamodules/deepglobelandcover.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from typing import Any, Dict, Optional

import matplotlib.pyplot as plt
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose
Expand Down Expand Up @@ -122,3 +123,10 @@ def test_dataloader(self) -> DataLoader[Dict[str, Any]]:
num_workers=self.num_workers,
shuffle=False,
)

def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
"""Run :meth:`torchgeo.datasets.DeepGlobeLandCover.plot`.

.. versionadded:: 0.4
"""
return self.test_dataset.plot(*args, **kwargs)
12 changes: 10 additions & 2 deletions torchgeo/datamodules/fair1m.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from typing import Any, Dict, List, Optional

import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
from torch import Tensor
Expand Down Expand Up @@ -84,9 +85,9 @@ def setup(self, stage: Optional[str] = None) -> None:
Args:
stage: stage to set up
"""
dataset = FAIR1M(transforms=self.preprocess, **self.kwargs)
self.dataset = FAIR1M(transforms=self.preprocess, **self.kwargs)
self.train_dataset, self.val_dataset, self.test_dataset = dataset_split(
dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct
self.dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct
)

def train_dataloader(self) -> DataLoader[Any]:
Expand Down Expand Up @@ -130,3 +131,10 @@ def test_dataloader(self) -> DataLoader[Any]:
shuffle=False,
collate_fn=collate_fn,
)

def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
"""Run :meth:`torchgeo.datasets.FAIR1M.plot`.

.. versionadded:: 0.4
"""
return self.dataset.plot(*args, **kwargs)
8 changes: 8 additions & 0 deletions torchgeo/datamodules/loveda.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from typing import Any, Dict, Optional

import matplotlib.pyplot as plt
import pytorch_lightning as pl
from torch.utils.data import DataLoader

Expand Down Expand Up @@ -121,3 +122,10 @@ def test_dataloader(self) -> DataLoader[Any]:
num_workers=self.num_workers,
shuffle=False,
)

def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
"""Run :meth:`torchgeo.datasets.LoveDA.plot`.

.. versionadded:: 0.4
"""
return self.train_dataset.plot(*args, **kwargs)
Loading