Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add plot method and new band names for Sen12MS dataset #320

Closed
wants to merge 27 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
f619cfa
plotting method and new band names
nilsleh Dec 21, 2021
d985f31
Docs: get rid of "Return type: None" for procedures with no return va…
adamjstewart Dec 21, 2021
a3f5593
Add custom RasterDataset notebook (#283)
RitwikGupta Dec 21, 2021
5a57d6c
Lower headers one step on tutorial notebook
calebrob6 Dec 22, 2021
cbebc1e
Move DataModules to torchgeo.datamodules (#321)
adamjstewart Dec 24, 2021
3557848
Fix new mypy warnings (#324)
adamjstewart Dec 24, 2021
bae2bf1
Only test optional dataset deps on release branch (#327)
adamjstewart Dec 24, 2021
20459d7
Mark features added in #144 as new (#328)
adamjstewart Dec 24, 2021
9bc0bdb
Move imports out of global namespace (#325)
adamjstewart Dec 24, 2021
469d02c
Convert rST comments to directives
adamjstewart Dec 24, 2021
324f0f8
Use labeler bot to automatically label PRs (#323)
adamjstewart Dec 24, 2021
41533ff
Update hooks (#330)
ashnair1 Dec 26, 2021
33efc2c
fix imports in train.py (#332)
isaaccorley Dec 27, 2021
0abb780
More uniform name for labeler action
adamjstewart Dec 27, 2021
45df0f8
Add plotting method for COWC dataset (#300)
nilsleh Dec 28, 2021
0d4811b
functionality for learning on the prior with QR loss and ChesapeakeCV…
estherrolf Dec 28, 2021
7ad3194
Add plot method to ADVANCE, BeninSmallHolderCashews, and BigEarthNet …
calebrob6 Dec 28, 2021
f18c428
Quick fix for losses docs (#333)
calebrob6 Dec 28, 2021
0f969d3
testing real data
nilsleh Dec 29, 2021
a9f004e
Add plotting method for CV4A Kenya Crop Type Dataset (#312)
nilsleh Dec 29, 2021
72d7507
requested changes and plot igbp as mask
nilsleh Dec 30, 2021
61e10f7
Add plot method to Levir, and change directory path (#335)
nilsleh Dec 30, 2021
744078f
Refactor datamodule/model testing (#329)
adamjstewart Dec 30, 2021
452615d
plotting method and new band names
nilsleh Dec 21, 2021
eba5451
testing real data
nilsleh Dec 29, 2021
b5e1865
requested changes and plot igbp as mask
nilsleh Dec 30, 2021
581260b
conflict merge
nilsleh Dec 30, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions tests/datasets/test_sen12ms.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pathlib import Path
from typing import Generator

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
Expand Down Expand Up @@ -83,6 +84,19 @@ def test_band_subsets(self) -> None:
x = ds[0]["image"]
assert x.shape[0] == len(bands)

def test_invalid_bands(self) -> None:
with pytest.raises(ValueError):
SEN12MS(bands=tuple(["OK", "BK"]))

def test_plot(self, dataset: SEN12MS) -> None:
dataset.plot(dataset[0], suptitle="Test")
plt.close()

sample = dataset[0]
sample["prediction"] = sample["mask"].clone()
dataset.plot(sample, suptitle="prediction")
plt.close()


class TestSEN12MSDataModule:
@pytest.fixture(scope="class", params=["all", "s1", "s2-all", "s2-reduced"])
Expand Down
204 changes: 174 additions & 30 deletions torchgeo/datasets/sen12ms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
"""SEN12MS dataset."""

import os
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import rasterio
Expand Down Expand Up @@ -69,50 +70,113 @@ class SEN12MS(VisionDataset):
This download will likely take several hours.
""" # noqa: E501

BAND_SETS: Dict[str, List[int]] = {
"all": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
"s1": [0, 1],
"s2-all": [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
"s2-reduced": [3, 4, 5, 9, 12, 13],
# BAND_SETS: Dict[str, List[int]] = {
# "all": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
# "s1": [0, 1],
# "s2-all": [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
# "s2-reduced": [3, 4, 5, 9, 12, 13],
# }

BAND_SETS: Dict[str, Tuple[str, ...]] = {
"all": tuple(
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
[
"VV",
"VH",
"B01",
"B02",
"B03",
"B04",
"B05",
"B06",
"B07",
"B08",
"B8A",
"B09",
"B10",
"B11",
"B12",
]
),
"s1": tuple(["VV", "VH"]),
"s2-all": tuple(
[
"B01",
"B02",
"B03",
"B04",
"B05",
"B06",
"B07",
"B08",
"B8A",
"B09",
"B10",
"B11",
"B12",
]
),
"s2-reduced": tuple(["B02", "B03", "B04", "B08", "B10", "B11"]),
}

band_names = tuple(
[
"VV",
"VH",
"B01",
"B02",
"B03",
"B04",
"B05",
"B06",
"B07",
"B08",
"B8A",
"B09",
"B10",
"B11",
"B12",
]
)

RGB_BANDS = ["B04", "B03", "B02"]

filenames = [
"ROIs1158_spring_lc.tar.gz",
"ROIs1158_spring_s1.tar.gz",
"ROIs1158_spring_s2.tar.gz",
"ROIs1868_summer_lc.tar.gz",
"ROIs1868_summer_s1.tar.gz",
"ROIs1868_summer_s2.tar.gz",
"ROIs1970_fall_lc.tar.gz",
"ROIs1970_fall_s1.tar.gz",
"ROIs1970_fall_s2.tar.gz",
"ROIs2017_winter_lc.tar.gz",
"ROIs2017_winter_s1.tar.gz",
"ROIs2017_winter_s2.tar.gz",
# "ROIs1868_summer_lc.tar.gz",
# "ROIs1868_summer_s1.tar.gz",
# "ROIs1868_summer_s2.tar.gz",
# "ROIs1970_fall_lc.tar.gz",
# "ROIs1970_fall_s1.tar.gz",
# "ROIs1970_fall_s2.tar.gz",
# "ROIs2017_winter_lc.tar.gz",
# "ROIs2017_winter_s1.tar.gz",
# "ROIs2017_winter_s2.tar.gz",
"train_list.txt",
"test_list.txt",
]
light_filenames = [
"ROIs1158_spring",
"ROIs1868_summer",
"ROIs1970_fall",
"ROIs2017_winter",
# "ROIs1868_summer",
# "ROIs1970_fall",
# "ROIs2017_winter",
"train_list.txt",
"test_list.txt",
]
md5s = [
"6e2e8fa8b8cba77ddab49fd20ff5c37b",
"fba019bb27a08c1db96b31f718c34d79",
"d58af2c15a16f376eb3308dc9b685af2",
"2c5bd80244440b6f9d54957c6b1f23d4",
"01044b7f58d33570c6b57fec28a3d449",
"4dbaf72ecb704a4794036fe691427ff3",
"9b126a68b0e3af260071b3139cb57cee",
"19132e0aab9d4d6862fd42e8e6760847",
"b8f117818878da86b5f5e06400eb1866",
"0fa0420ef7bcfe4387c7e6fe226dc728",
"bb8cbfc16b95a4f054a3d5380e0130ed",
"3807545661288dcca312c9c538537b63",
# "2c5bd80244440b6f9d54957c6b1f23d4",
# "01044b7f58d33570c6b57fec28a3d449",
# "4dbaf72ecb704a4794036fe691427ff3",
# "9b126a68b0e3af260071b3139cb57cee",
# "19132e0aab9d4d6862fd42e8e6760847",
# "b8f117818878da86b5f5e06400eb1866",
# "0fa0420ef7bcfe4387c7e6fe226dc728",
# "bb8cbfc16b95a4f054a3d5380e0130ed",
# "3807545661288dcca312c9c538537b63",
"0a68d4e1eb24f128fccdb930000b2546",
"c7faad064001e646445c4c634169484d",
]
Expand All @@ -121,7 +185,7 @@ def __init__(
self,
root: str = "data",
split: str = "train",
bands: List[int] = BAND_SETS["all"],
bands: Tuple[str, ...] = BAND_SETS["all"],
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
checksum: bool = False,
) -> None:
Expand All @@ -147,9 +211,14 @@ def __init__(
"""
assert split in ["train", "test"]

self._validate_bands(bands)
self.band_indices = torch.tensor( # type: ignore[attr-defined]
[self.band_names.index(b) for b in bands]
).long()
self.bands = bands

self.root = root
self.split = split
self.bands = torch.tensor(bands).long() # type: ignore[attr-defined]
self.transforms = transforms
self.checksum = checksum

Expand Down Expand Up @@ -180,7 +249,7 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]:

image = torch.cat(tensors=[s1, s2], dim=0) # type: ignore[attr-defined]
image = torch.index_select( # type: ignore[attr-defined]
image, dim=0, index=self.bands
image, dim=0, index=self.band_indices
)

sample: Dict[str, Tensor] = {"image": image, "mask": lc}
Expand Down Expand Up @@ -223,12 +292,30 @@ def _load_raster(self, filename: str, source: str) -> Tensor:
tensor: Tensor = torch.from_numpy(array) # type: ignore[attr-defined]
return tensor

def _validate_bands(self, bands: Tuple[str, ...]) -> None:
"""Validate list of bands.

Args:
bands: user-provided tuple of bands to load

Raises:
AssertionError: if ``bands`` is not a tuple
ValueError: if an invalid band name is provided
"""
assert isinstance(bands, tuple), "The list of bands must be a tuple"
for band in bands:
if band not in self.band_names:
raise ValueError(f"'{band}' is an invalid band name.")

def _check_integrity_light(self) -> bool:
"""Checks the integrity of the dataset structure.

Returns:
True if the dataset directories and split files are found, else False
"""
import pdb

pdb.set_trace()
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
for filename in self.light_filenames:
filepath = os.path.join(self.root, filename)
if not os.path.exists(filepath):
Expand All @@ -247,6 +334,63 @@ def _check_integrity(self) -> bool:
return False
return True

def plot(
self,
sample: Dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
"""Plot a sample from the dataset.

Args:
sample: a sample return by :meth:`__getitem__`
show_titles: flag indicating whether to show titles above each panel
suptitle: optional suptitle to use for figure

Returns;
a matplotlib Figure with the rendered sample

.. versionadded:: 0.2
"""
rgb_indices = []
for band in self.RGB_BANDS:
if band in self.bands:
rgb_indices.append(self.bands.index(band))
else:
raise ValueError("Dataset doesn't contain some of the RGB bands")
import pdb

pdb.set_trace()
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
image, _ = sample["image"][rgb_indices, ...], sample["mask"]
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
ncols = 2

showing_predictions = "prediction" in sample
if showing_predictions:
_ = sample["prediction"]
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
ncols += 1

fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(10, ncols * 5))

axs[0].imshow(image.permute(1, 2, 0))
axs[0].axis("off")
# axs[1].imshow(mask)
# axs[1].axis("off")
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved

if showing_predictions:
# axs[2].imshow(prediction)
axs[2].axis("off")

if show_titles:
axs[0].set_title("Image")
axs[1].set_title("Mask")
if showing_predictions:
axs[2].set_title("Prediction")

if suptitle is not None:
plt.suptitle(suptitle)

return fig


class SEN12MSDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the SEN12MS dataset.
Expand Down