From 1155eaefeb5c0b8f4800767aa5777b499095bc8c Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Thu, 23 Feb 2023 13:39:20 -0700 Subject: [PATCH] RasterDataset: fix band indexing bug (#1135) * Fix Landsat non-SR band specification * Fix variable reference * Fix test when no all_bands * Fail during init instead * Store variable * all_bands -> default_bands * Simplify logic * fix mypy * Make default_bands required --- tests/datasets/test_geo.py | 3 ++- tests/datasets/test_landsat.py | 12 +++++++++--- torchgeo/datasets/geo.py | 28 ++++++++++++++-------------- torchgeo/datasets/landsat.py | 28 +++++++++++++++++++++------- 4 files changed, 46 insertions(+), 25 deletions(-) diff --git a/tests/datasets/test_geo.py b/tests/datasets/test_geo.py index a24dd3653f2..158117c9d88 100644 --- a/tests/datasets/test_geo.py +++ b/tests/datasets/test_geo.py @@ -52,6 +52,7 @@ class CustomVectorDataset(VectorDataset): class CustomSentinelDataset(Sentinel2): all_bands: List[str] = [] + separate_files = False class CustomNonGeoDataset(NonGeoDataset): @@ -214,7 +215,7 @@ def test_no_data(self, tmp_path: Path) -> None: with pytest.raises(FileNotFoundError, match="No RasterDataset data was found"): RasterDataset(str(tmp_path)) - def test_no_allbands(self) -> None: + def test_no_all_bands(self) -> None: root = os.path.join("tests", "data", "sentinel2") bands = ["B04", "B03", "B02"] transforms = nn.Identity() diff --git a/tests/datasets/test_landsat.py b/tests/datasets/test_landsat.py index bc1ff2c8aea..e82151a05e6 100644 --- a/tests/datasets/test_landsat.py +++ b/tests/datasets/test_landsat.py @@ -8,6 +8,7 @@ import pytest import torch import torch.nn as nn +from _pytest.fixtures import SubRequest from _pytest.monkeypatch import MonkeyPatch from rasterio.crs import CRS @@ -15,10 +16,15 @@ class TestLandsat8: - @pytest.fixture - def dataset(self, monkeypatch: MonkeyPatch) -> Landsat8: + @pytest.fixture( + params=[ + ["SR_B1", "SR_B2", "SR_B3", "SR_B4", "SR_B5", "SR_B6", "SR_B7"], + ["SR_B4", "SR_B3", "SR_B2", "SR_QA_AEROSOL"], + ] + ) + def dataset(self, monkeypatch: MonkeyPatch, request: SubRequest) -> Landsat8: root = os.path.join("tests", "data", "landsat8") - bands = ["SR_B1", "SR_B2", "SR_B3", "SR_B4", "SR_B5", "SR_B6", "SR_B7"] + bands = request.param transforms = nn.Identity() return Landsat8(root, bands=bands, transforms=transforms) diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index 96ae41e20a3..aede7e6a405 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -323,6 +323,7 @@ def __init__( super().__init__(transforms) self.root = root + self.bands = bands or self.all_bands self.cache = cache # Populate the dataset index @@ -367,21 +368,20 @@ def __init__( f"No {self.__class__.__name__} data was found in '{root}'" ) - if bands and self.all_bands: - band_indexes = [self.all_bands.index(i) + 1 for i in bands] - self.bands = bands - assert len(band_indexes) == len(self.bands) - elif bands: - msg = ( - f"{self.__class__.__name__} is missing an `all_bands` attribute," - " so `bands` cannot be specified." - ) - raise AssertionError(msg) - else: - band_indexes = None - self.bands = self.all_bands + if not self.separate_files: + self.band_indexes = None + if self.bands: + if self.all_bands: + self.band_indexes = [ + self.all_bands.index(i) + 1 for i in self.bands + ] + else: + msg = ( + f"{self.__class__.__name__} is missing an `all_bands` " + "attribute, so `bands` cannot be specified." + ) + raise AssertionError(msg) - self.band_indexes = band_indexes self._crs = cast(CRS, crs) self.res = cast(float, res) diff --git a/torchgeo/datasets/landsat.py b/torchgeo/datasets/landsat.py index 48754d7f1e2..c1bd6630a52 100644 --- a/torchgeo/datasets/landsat.py +++ b/torchgeo/datasets/landsat.py @@ -4,7 +4,7 @@ """Landsat datasets.""" import abc -from typing import Any, Callable, Dict, Optional, Sequence +from typing import Any, Callable, Dict, List, Optional, Sequence import matplotlib.pyplot as plt from rasterio.crs import CRS @@ -49,6 +49,11 @@ class Landsat(RasterDataset, abc.ABC): separate_files = True + @property + @abc.abstractmethod + def default_bands(self) -> List[str]: + """Bands to load by default.""" + def __init__( self, root: str = "data", @@ -74,7 +79,7 @@ def __init__( Raises: FileNotFoundError: if no files are found in ``root`` """ - bands = bands or self.all_bands + bands = bands or self.default_bands self.filename_glob = self.filename_glob.format(bands[0]) super().__init__(root, crs, res, bands, transforms, cache) @@ -133,7 +138,7 @@ class Landsat1(Landsat): filename_glob = "LM01_*_{}.*" - all_bands = ["SR_B4", "SR_B5", "SR_B6", "SR_B7"] + default_bands = ["SR_B4", "SR_B5", "SR_B6", "SR_B7"] rgb_bands = ["SR_B6", "SR_B5", "SR_B4"] @@ -154,7 +159,7 @@ class Landsat4MSS(Landsat): filename_glob = "LM04_*_{}.*" - all_bands = ["SR_B1", "SR_B2", "SR_B3", "SR_B4"] + default_bands = ["SR_B1", "SR_B2", "SR_B3", "SR_B4"] rgb_bands = ["SR_B3", "SR_B2", "SR_B1"] @@ -163,7 +168,7 @@ class Landsat4TM(Landsat): filename_glob = "LT04_*_{}.*" - all_bands = ["SR_B1", "SR_B2", "SR_B3", "SR_B4", "SR_B5", "SR_B6", "SR_B7"] + default_bands = ["SR_B1", "SR_B2", "SR_B3", "SR_B4", "SR_B5", "SR_B6", "SR_B7"] rgb_bands = ["SR_B3", "SR_B2", "SR_B1"] @@ -184,7 +189,16 @@ class Landsat7(Landsat): filename_glob = "LE07_*_{}.*" - all_bands = ["SR_B1", "SR_B2", "SR_B3", "SR_B4", "SR_B5", "SR_B6", "SR_B7", "SR_B8"] + default_bands = [ + "SR_B1", + "SR_B2", + "SR_B3", + "SR_B4", + "SR_B5", + "SR_B6", + "SR_B7", + "SR_B8", + ] rgb_bands = ["SR_B3", "SR_B2", "SR_B1"] @@ -193,7 +207,7 @@ class Landsat8(Landsat): filename_glob = "LC08_*_{}.*" - all_bands = [ + default_bands = [ "SR_B1", "SR_B2", "SR_B3",