Skip to content

Commit

Permalink
RasterDataset: fix band indexing bug (#1135)
Browse files Browse the repository at this point in the history
* Fix Landsat non-SR band specification

* Fix variable reference

* Fix test when no all_bands

* Fail during init instead

* Store variable

* all_bands -> default_bands

* Simplify logic

* fix mypy

* Make default_bands required
  • Loading branch information
adamjstewart authored Feb 23, 2023
1 parent 524030c commit 1155eae
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 25 deletions.
3 changes: 2 additions & 1 deletion tests/datasets/test_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class CustomVectorDataset(VectorDataset):

class CustomSentinelDataset(Sentinel2):
all_bands: List[str] = []
separate_files = False


class CustomNonGeoDataset(NonGeoDataset):
Expand Down Expand Up @@ -214,7 +215,7 @@ def test_no_data(self, tmp_path: Path) -> None:
with pytest.raises(FileNotFoundError, match="No RasterDataset data was found"):
RasterDataset(str(tmp_path))

def test_no_allbands(self) -> None:
def test_no_all_bands(self) -> None:
root = os.path.join("tests", "data", "sentinel2")
bands = ["B04", "B03", "B02"]
transforms = nn.Identity()
Expand Down
12 changes: 9 additions & 3 deletions tests/datasets/test_landsat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,23 @@
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from rasterio.crs import CRS

from torchgeo.datasets import BoundingBox, IntersectionDataset, Landsat8, UnionDataset


class TestLandsat8:
@pytest.fixture
def dataset(self, monkeypatch: MonkeyPatch) -> Landsat8:
@pytest.fixture(
params=[
["SR_B1", "SR_B2", "SR_B3", "SR_B4", "SR_B5", "SR_B6", "SR_B7"],
["SR_B4", "SR_B3", "SR_B2", "SR_QA_AEROSOL"],
]
)
def dataset(self, monkeypatch: MonkeyPatch, request: SubRequest) -> Landsat8:
root = os.path.join("tests", "data", "landsat8")
bands = ["SR_B1", "SR_B2", "SR_B3", "SR_B4", "SR_B5", "SR_B6", "SR_B7"]
bands = request.param
transforms = nn.Identity()
return Landsat8(root, bands=bands, transforms=transforms)

Expand Down
28 changes: 14 additions & 14 deletions torchgeo/datasets/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,7 @@ def __init__(
super().__init__(transforms)

self.root = root
self.bands = bands or self.all_bands
self.cache = cache

# Populate the dataset index
Expand Down Expand Up @@ -367,21 +368,20 @@ def __init__(
f"No {self.__class__.__name__} data was found in '{root}'"
)

if bands and self.all_bands:
band_indexes = [self.all_bands.index(i) + 1 for i in bands]
self.bands = bands
assert len(band_indexes) == len(self.bands)
elif bands:
msg = (
f"{self.__class__.__name__} is missing an `all_bands` attribute,"
" so `bands` cannot be specified."
)
raise AssertionError(msg)
else:
band_indexes = None
self.bands = self.all_bands
if not self.separate_files:
self.band_indexes = None
if self.bands:
if self.all_bands:
self.band_indexes = [
self.all_bands.index(i) + 1 for i in self.bands
]
else:
msg = (
f"{self.__class__.__name__} is missing an `all_bands` "
"attribute, so `bands` cannot be specified."
)
raise AssertionError(msg)

self.band_indexes = band_indexes
self._crs = cast(CRS, crs)
self.res = cast(float, res)

Expand Down
28 changes: 21 additions & 7 deletions torchgeo/datasets/landsat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""Landsat datasets."""

import abc
from typing import Any, Callable, Dict, Optional, Sequence
from typing import Any, Callable, Dict, List, Optional, Sequence

import matplotlib.pyplot as plt
from rasterio.crs import CRS
Expand Down Expand Up @@ -49,6 +49,11 @@ class Landsat(RasterDataset, abc.ABC):

separate_files = True

@property
@abc.abstractmethod
def default_bands(self) -> List[str]:
"""Bands to load by default."""

def __init__(
self,
root: str = "data",
Expand All @@ -74,7 +79,7 @@ def __init__(
Raises:
FileNotFoundError: if no files are found in ``root``
"""
bands = bands or self.all_bands
bands = bands or self.default_bands
self.filename_glob = self.filename_glob.format(bands[0])

super().__init__(root, crs, res, bands, transforms, cache)
Expand Down Expand Up @@ -133,7 +138,7 @@ class Landsat1(Landsat):

filename_glob = "LM01_*_{}.*"

all_bands = ["SR_B4", "SR_B5", "SR_B6", "SR_B7"]
default_bands = ["SR_B4", "SR_B5", "SR_B6", "SR_B7"]
rgb_bands = ["SR_B6", "SR_B5", "SR_B4"]


Expand All @@ -154,7 +159,7 @@ class Landsat4MSS(Landsat):

filename_glob = "LM04_*_{}.*"

all_bands = ["SR_B1", "SR_B2", "SR_B3", "SR_B4"]
default_bands = ["SR_B1", "SR_B2", "SR_B3", "SR_B4"]
rgb_bands = ["SR_B3", "SR_B2", "SR_B1"]


Expand All @@ -163,7 +168,7 @@ class Landsat4TM(Landsat):

filename_glob = "LT04_*_{}.*"

all_bands = ["SR_B1", "SR_B2", "SR_B3", "SR_B4", "SR_B5", "SR_B6", "SR_B7"]
default_bands = ["SR_B1", "SR_B2", "SR_B3", "SR_B4", "SR_B5", "SR_B6", "SR_B7"]
rgb_bands = ["SR_B3", "SR_B2", "SR_B1"]


Expand All @@ -184,7 +189,16 @@ class Landsat7(Landsat):

filename_glob = "LE07_*_{}.*"

all_bands = ["SR_B1", "SR_B2", "SR_B3", "SR_B4", "SR_B5", "SR_B6", "SR_B7", "SR_B8"]
default_bands = [
"SR_B1",
"SR_B2",
"SR_B3",
"SR_B4",
"SR_B5",
"SR_B6",
"SR_B7",
"SR_B8",
]
rgb_bands = ["SR_B3", "SR_B2", "SR_B1"]


Expand All @@ -193,7 +207,7 @@ class Landsat8(Landsat):

filename_glob = "LC08_*_{}.*"

all_bands = [
default_bands = [
"SR_B1",
"SR_B2",
"SR_B3",
Expand Down

0 comments on commit 1155eae

Please sign in to comment.