Skip to content

Commit

Permalink
Increase coverage of RasterDataset
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjstewart committed Sep 7, 2021
1 parent b35a202 commit 1590397
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions tests/datasets/test_geo.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,26 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
from pathlib import Path
from typing import Dict

import pytest
import torch
from _pytest.fixtures import SubRequest
from rasterio.crs import CRS
from torch.utils.data import ConcatDataset

from torchgeo.datasets import (
BoundingBox,
GeoDataset,
Landsat8,
RasterDataset,
VectorDataset,
VisionDataset,
ZipDataset,
)
from torchgeo.transforms import Identity


class CustomGeoDataset(GeoDataset):
Expand Down Expand Up @@ -87,6 +92,21 @@ def test_add_vision(self, dataset: GeoDataset) -> None:


class TestRasterDataset:
@pytest.fixture(params=[True, False])
def dataset(self, request: SubRequest) -> Landsat8:
root = os.path.join("tests", "data", "landsat8")
bands = ["B1", "B2", "B3", "B4", "B5", "B6", "B7"]
crs = CRS.from_epsg(3005)
transforms = Identity()
cache = request.param
return Landsat8(root, bands=bands, crs=crs, transforms=transforms, cache=cache)

def test_getitem(self, dataset: Landsat8) -> None:
x = dataset[dataset.bounds]
assert isinstance(x, dict)
assert isinstance(x["crs"], CRS)
assert isinstance(x["image"], torch.Tensor)

def test_no_data(self, tmp_path: Path) -> None:
with pytest.raises(FileNotFoundError, match="No RasterDataset data was found"):
RasterDataset(str(tmp_path))
Expand Down

0 comments on commit 1590397

Please sign in to comment.