diff --git a/tests/samplers/test_single.py b/tests/samplers/test_single.py index 556de8c6587..0528dfe78c4 100644 --- a/tests/samplers/test_single.py +++ b/tests/samplers/test_single.py @@ -182,9 +182,9 @@ def test_iter(self, sampler: GridGeoSampler) -> None: ) def test_len(self, sampler: GridGeoSampler) -> None: - rows = ((100 - sampler.size[0]) // sampler.stride[0]) + 1 - cols = ((100 - sampler.size[1]) // sampler.stride[1]) + 1 - length = rows * cols * 2 + rows = math.ceil((100 - sampler.size[0]) / sampler.stride[0]) + 1 + cols = math.ceil((100 - sampler.size[1]) / sampler.stride[1]) + 1 + length = rows * cols * 2 # two items in dataset assert len(sampler) == length def test_roi(self, dataset: CustomGeoDataset) -> None: @@ -194,12 +194,35 @@ def test_roi(self, dataset: CustomGeoDataset) -> None: assert query in roi def test_small_area(self) -> None: + ds = CustomGeoDataset() + ds.index.insert(0, (0, 1, 0, 1, 0, 1)) + sampler = GridGeoSampler(ds, 2, 10) + assert len(sampler) == 0 + + def test_tiles_side_by_side(self) -> None: ds = CustomGeoDataset() ds.index.insert(0, (0, 10, 0, 10, 0, 10)) - ds.index.insert(1, (20, 21, 20, 21, 20, 21)) + ds.index.insert(0, (0, 10, 10, 20, 0, 10)) sampler = GridGeoSampler(ds, 2, 10) - for _ in sampler: - continue + for bbox in sampler: + assert bbox.area > 0 + + def test_integer_multiple(self) -> None: + ds = CustomGeoDataset() + ds.index.insert(0, (0, 10, 0, 10, 0, 10)) + sampler = GridGeoSampler(ds, 10, 10, units=Units.CRS) + iterator = iter(sampler) + assert len(sampler) == 1 + assert next(iterator) == BoundingBox(0, 10, 0, 10, 0, 10) + + def test_float_multiple(self) -> None: + ds = CustomGeoDataset() + ds.index.insert(0, (0, 6, 0, 5, 0, 10)) + sampler = GridGeoSampler(ds, 5, 5, units=Units.CRS) + iterator = iter(sampler) + assert len(sampler) == 2 + assert next(iterator) == BoundingBox(0, 5, 0, 5, 0, 10) + assert next(iterator) == BoundingBox(1, 6, 0, 5, 0, 10) @pytest.mark.slow @pytest.mark.parametrize("num_workers", [0, 1, 
2]) diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 92930a24382..e063d9ecbd4 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -4,6 +4,7 @@ """TorchGeo samplers.""" import abc +import math from typing import Callable, Iterable, Iterator, Optional, Tuple, Union import torch @@ -146,7 +147,7 @@ def __len__(self) -> int: class GridGeoSampler(GeoSampler): - """Samples elements in a grid-like fashion. + r"""Samples elements in a grid-like fashion. This is particularly useful during evaluation when you want to make predictions for an entire region of interest. You want to minimize the amount of redundant @@ -158,6 +159,21 @@ class GridGeoSampler(GeoSampler): The overlap between each chip (``chip_size - stride``) should be approximately equal to the `receptive field <https://distill.pub/2019/computing-receptive-fields/>`_ of the CNN. + + Note that the stride of the final set of chips in each row/column may be adjusted so + that the entire :term:`tile` is sampled without exceeding the bounds of the dataset. + + Let :math:`i` be the size of the input tile. Let :math:`k` be the requested size of + the output patch. Let :math:`s` be the requested stride. Let :math:`o` be the number + of output rows/columns sampled from each tile. :math:`o` can then be computed as: + + .. math:: + + o = \left\lceil \frac{i - k}{s} \right\rceil + 1 + + This is almost identical to relationship 5 in + https://doi.org/10.48550/arXiv.1603.07285. However, we use ceiling instead of floor + because we want to include the final remaining chip. 
""" def __init__( @@ -200,8 +216,8 @@ def __init__( for hit in self.index.intersection(tuple(self.roi), objects=True): bounds = BoundingBox(*hit.bounds) if ( - bounds.maxx - bounds.minx > self.size[1] - and bounds.maxy - bounds.miny > self.size[0] + bounds.maxx - bounds.minx >= self.size[1] + and bounds.maxy - bounds.miny >= self.size[0] ): self.hits.append(hit) @@ -209,8 +225,14 @@ def __init__( for hit in self.hits: bounds = BoundingBox(*hit.bounds) - rows = int((bounds.maxy - bounds.miny - self.size[0]) // self.stride[0]) + 1 - cols = int((bounds.maxx - bounds.minx - self.size[1]) // self.stride[1]) + 1 + rows = ( + math.ceil((bounds.maxy - bounds.miny - self.size[0]) / self.stride[0]) + + 1 + ) + cols = ( + math.ceil((bounds.maxx - bounds.minx - self.size[1]) / self.stride[1]) + + 1 + ) self.length += rows * cols def __iter__(self) -> Iterator[BoundingBox]: @@ -223,8 +245,14 @@ def __iter__(self) -> Iterator[BoundingBox]: for hit in self.hits: bounds = BoundingBox(*hit.bounds) - rows = int((bounds.maxy - bounds.miny - self.size[0]) // self.stride[0]) + 1 - cols = int((bounds.maxx - bounds.minx - self.size[1]) // self.stride[1]) + 1 + rows = ( + math.ceil((bounds.maxy - bounds.miny - self.size[0]) / self.stride[0]) + + 1 + ) + cols = ( + math.ceil((bounds.maxx - bounds.minx - self.size[1]) / self.stride[1]) + + 1 + ) mint = bounds.mint maxt = bounds.maxt @@ -233,11 +261,17 @@ def __iter__(self) -> Iterator[BoundingBox]: for i in range(rows): miny = bounds.miny + i * self.stride[0] maxy = miny + self.size[0] + if maxy > bounds.maxy: + maxy = bounds.maxy + miny = bounds.maxy - self.size[0] # For each column... for j in range(cols): minx = bounds.minx + j * self.stride[1] maxx = minx + self.size[1] + if maxx > bounds.maxx: + maxx = bounds.maxx + minx = bounds.maxx - self.size[1] yield BoundingBox(minx, maxx, miny, maxy, mint, maxt)