diff --git a/.gitignore b/.gitignore index d490e65b1da..29e2b022a4a 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ /output/ *.pdf /results/ +*.aux.xml # Spack .spack-env/ diff --git a/tests/data/raster/data.py b/tests/data/raster/data.py index cc00aafef9b..517649607ba 100755 --- a/tests/data/raster/data.py +++ b/tests/data/raster/data.py @@ -2,41 +2,99 @@ # Licensed under the MIT License. import os +from typing import Optional import numpy as np -import rasterio -import rasterio.transform -from torchvision.datasets.utils import calculate_md5 +import rasterio as rio +from rasterio.transform import from_bounds +from rasterio.warp import calculate_default_transform, reproject +RES = [2, 4, 8] +EPSG = [4087, 4326, 32631] +SIZE = 16 -def generate_test_data(fn: str) -> str: - """Creates test data with uint32 datatype. - Args: - fn (str): Filename to write +def write_raster( + res: int = RES[0], + epsg: int = EPSG[0], + dtype: str = "uint8", + path: Optional[str] = None, +) -> None: + """Write a raster file. - Returns: - str: md5 hash of created archive + Args: + res: Resolution. + epsg: EPSG of file. + dtype: Data type. + path: File path. """ + size = SIZE // res profile = { "driver": "GTiff", - "dtype": "uint32", + "dtype": dtype, "count": 1, - "crs": "epsg:4326", - "transform": rasterio.transform.from_bounds(0, 0, 1, 1, 1, 1), - "height": 4, - "width": 4, - "compress": "lzw", - "predictor": 2, + "crs": f"epsg:{epsg}", + "transform": from_bounds(0, 0, SIZE, SIZE, size, size), + "height": size, + "width": size, + "nodata": 0, } - with rasterio.open(fn, "w", **profile) as f: - f.write(np.random.randint(0, 256, size=(1, 4, 4))) + if path is None: + name = f"res_{res}_epsg_{epsg}" + path = os.path.join(name, f"{name}.tif") + + directory = os.path.dirname(path) + os.makedirs(directory, exist_ok=True) + + with rio.open(path, "w", **profile) as f: + x = np.ones((1, size, size)) + f.write(x) + - md5: str = calculate_md5(fn) - return md5 +def reproject_raster(res: int, src_epsg: int, dst_epsg: int) -> None: + """Reproject a raster file. + + Args: + res: Resolution. + src_epsg: EPSG of source file. + dst_epsg: EPSG of destination file. + """ + src_name = f"res_{res}_epsg_{src_epsg}" + src_path = os.path.join(src_name, f"{src_name}.tif") + with rio.open(src_path) as src: + dst_crs = f"epsg:{dst_epsg}" + transform, width, height = calculate_default_transform( + src.crs, dst_crs, src.width, src.height, *src.bounds + ) + profile = src.profile.copy() + profile.update( + {"crs": dst_crs, "transform": transform, "width": width, "height": height} + ) + dst_name = f"res_{res}_epsg_{dst_epsg}" + os.makedirs(dst_name, exist_ok=True) + dst_path = os.path.join(dst_name, f"{dst_name}.tif") + with rio.open(dst_path, "w", **profile) as dst: + reproject( + source=rio.band(src, 1), + destination=rio.band(dst, 1), + src_transform=src.transform, + src_crs=src.crs, + dst_transform=dst.transform, + dst_crs=dst.crs, + ) if __name__ == "__main__": - md5_hash = generate_test_data(os.path.join(os.getcwd(), "test0.tif")) - print(md5_hash) + for res in RES: + src_epsg = EPSG[0] + write_raster(res, src_epsg) + + for dst_epsg in EPSG[1:]: + reproject_raster(res, src_epsg, dst_epsg) + + for dtype in ["uint16", "uint32"]: + path = os.path.join(dtype, f"{dtype}.tif") + write_raster(dtype=dtype, path=path) + with open(os.path.join(dtype, "corrupted.tif"), "w") as f: + f.write("not a tif file\n") diff --git a/tests/data/raster/res_2_epsg_32631/res_2_epsg_32631.tif b/tests/data/raster/res_2_epsg_32631/res_2_epsg_32631.tif new file mode 100644 index 00000000000..caa023eed76 Binary files /dev/null and b/tests/data/raster/res_2_epsg_32631/res_2_epsg_32631.tif differ diff --git a/tests/data/raster/res_2_epsg_4087/res_2_epsg_4087.tif b/tests/data/raster/res_2_epsg_4087/res_2_epsg_4087.tif new file mode 100644 index 00000000000..752611e87bb Binary files /dev/null and b/tests/data/raster/res_2_epsg_4087/res_2_epsg_4087.tif differ diff --git a/tests/data/raster/res_2_epsg_4326/res_2_epsg_4326.tif b/tests/data/raster/res_2_epsg_4326/res_2_epsg_4326.tif new file mode 100644 index 00000000000..b3362b8ac22 Binary files /dev/null and b/tests/data/raster/res_2_epsg_4326/res_2_epsg_4326.tif differ diff --git a/tests/data/raster/res_4_epsg_32631/res_4_epsg_32631.tif b/tests/data/raster/res_4_epsg_32631/res_4_epsg_32631.tif new file mode 100644 index 00000000000..dd69237d92b Binary files /dev/null and b/tests/data/raster/res_4_epsg_32631/res_4_epsg_32631.tif differ diff --git a/tests/data/raster/res_4_epsg_4087/res_4_epsg_4087.tif b/tests/data/raster/res_4_epsg_4087/res_4_epsg_4087.tif new file mode 100644 index 00000000000..c9ce35ee860 Binary files /dev/null and b/tests/data/raster/res_4_epsg_4087/res_4_epsg_4087.tif differ diff --git a/tests/data/raster/res_4_epsg_4326/res_4_epsg_4326.tif b/tests/data/raster/res_4_epsg_4326/res_4_epsg_4326.tif new file mode 100644 index 00000000000..eb2d7a9c66d Binary files /dev/null and b/tests/data/raster/res_4_epsg_4326/res_4_epsg_4326.tif differ diff --git a/tests/data/raster/res_8_epsg_32631/res_8_epsg_32631.tif b/tests/data/raster/res_8_epsg_32631/res_8_epsg_32631.tif new file mode 100644 index 00000000000..92d838ab2c2 Binary files /dev/null and b/tests/data/raster/res_8_epsg_32631/res_8_epsg_32631.tif differ diff --git a/tests/data/raster/res_8_epsg_4087/res_8_epsg_4087.tif b/tests/data/raster/res_8_epsg_4087/res_8_epsg_4087.tif new file mode 100644 index 00000000000..0b989058f16 Binary files /dev/null and b/tests/data/raster/res_8_epsg_4087/res_8_epsg_4087.tif differ diff --git a/tests/data/raster/res_8_epsg_4326/res_8_epsg_4326.tif b/tests/data/raster/res_8_epsg_4326/res_8_epsg_4326.tif new file mode 100644 index 00000000000..aa0a7318bcc Binary files /dev/null and b/tests/data/raster/res_8_epsg_4326/res_8_epsg_4326.tif differ diff --git a/tests/data/raster/test0.tif b/tests/data/raster/test0.tif deleted file mode 100644 index 84df1f7cb14..00000000000 Binary files a/tests/data/raster/test0.tif and /dev/null differ diff --git a/tests/data/raster/uint16/corrupted.tif b/tests/data/raster/uint16/corrupted.tif new file mode 100644 index 00000000000..42e548ffea8 --- /dev/null +++ b/tests/data/raster/uint16/corrupted.tif @@ -0,0 +1 @@ +not a tif file diff --git a/tests/data/raster/uint16/uint16.tif b/tests/data/raster/uint16/uint16.tif new file mode 100644 index 00000000000..05e38bdc42e Binary files /dev/null and b/tests/data/raster/uint16/uint16.tif differ diff --git a/tests/data/raster/uint32/corrupted.tif b/tests/data/raster/uint32/corrupted.tif new file mode 100644 index 00000000000..42e548ffea8 --- /dev/null +++ b/tests/data/raster/uint32/corrupted.tif @@ -0,0 +1 @@ +not a tif file diff --git a/tests/data/raster/uint32/uint32.tif b/tests/data/raster/uint32/uint32.tif new file mode 100644 index 00000000000..e42c41fc470 Binary files /dev/null and b/tests/data/raster/uint32/uint32.tif differ diff --git a/tests/datasets/test_geo.py b/tests/datasets/test_geo.py index 259c583ebe2..cf4ef25d880 100644 --- a/tests/datasets/test_geo.py +++ b/tests/datasets/test_geo.py @@ -30,7 +30,7 @@ class CustomGeoDataset(GeoDataset): def __init__( self, bounds: BoundingBox = BoundingBox(0, 1, 2, 3, 4, 5), - crs: CRS = CRS.from_epsg(3005), + crs: CRS = CRS.from_epsg(4087), res: float = 1, ) -> None: super().__init__() @@ -74,7 +74,7 @@ def test_getitem(self, dataset: GeoDataset) -> None: def test_len(self, dataset: GeoDataset) -> None: assert len(dataset) == 1 - @pytest.mark.parametrize("crs", [CRS.from_epsg(3005), CRS.from_epsg(32616)]) + @pytest.mark.parametrize("crs", [CRS.from_epsg(4087), CRS.from_epsg(32631)]) def test_crs(self, dataset: GeoDataset, crs: CRS) -> None: dataset.crs = crs @@ -157,7 +157,7 @@ class TestRasterDataset: def naip(self, request: SubRequest) -> NAIP: root = os.path.join("tests", "data", "naip") bands = request.param[0] - crs = CRS.from_epsg(3005) + crs = CRS.from_epsg(4087) transforms = nn.Identity() cache = request.param[1] return NAIP(root, crs=crs, bands=bands, transforms=transforms, cache=cache) @@ -178,11 +178,6 @@ def sentinel(self, request: SubRequest) -> Sentinel2: cache = request.param[1] return Sentinel2(root, bands=bands, transforms=transforms, cache=cache) - @pytest.fixture() - def custom_dtype_ds(self) -> RasterDataset: - root = os.path.join("tests", "data", "raster") - return RasterDataset(root) - def test_getitem_single_file(self, naip: NAIP) -> None: x = naip[naip.bounds] assert isinstance(x, dict) @@ -197,8 +192,11 @@ def test_getitem_separate_files(self, sentinel: Sentinel2) -> None: assert isinstance(x["image"], torch.Tensor) assert len(sentinel.bands) == x["image"].shape[0] - def test_getitem_uint_dtype(self, custom_dtype_ds: RasterDataset) -> None: - x = custom_dtype_ds[custom_dtype_ds.bounds] + @pytest.mark.parametrize("dtype", ["uint16", "uint32"]) + def test_getitem_uint_dtype(self, dtype: str) -> None: + root = os.path.join("tests", "data", "raster", dtype) + ds = RasterDataset(root) + x = ds[ds.bounds] assert isinstance(x, dict) assert isinstance(x["image"], torch.Tensor) assert x["image"].dtype == torch.float32 @@ -377,14 +375,15 @@ def test_str(self, dataset: NonGeoClassificationDataset) -> None: class TestIntersectionDataset: @pytest.fixture(scope="class") def dataset(self) -> IntersectionDataset: - ds1 = CustomGeoDataset() - ds2 = CustomGeoDataset() + ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087")) + ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_4_epsg_4326")) transforms = nn.Identity() return IntersectionDataset(ds1, ds2, transforms=transforms) def test_getitem(self, dataset: IntersectionDataset) -> None: - query = BoundingBox(0, 1, 2, 3, 4, 5) - assert dataset[query] == {"index": query} + query = dataset.bounds + sample = dataset[query] + assert isinstance(sample["image"], torch.Tensor) def test_len(self, dataset: IntersectionDataset) -> None: assert len(dataset) == 1 @@ -403,27 +402,69 @@ def test_nongeo_dataset(self) -> None: ): IntersectionDataset(ds1, ds2) # type: ignore[arg-type] - def test_different_crs(self) -> None: - ds1 = CustomGeoDataset(BoundingBox(0, 1, 0, 1, 0, 1), crs=CRS.from_epsg(3005)) - ds2 = CustomGeoDataset( - BoundingBox( - -3547229.913123814, - 6360089.518213182, - -3547229.913123814, - 6360089.518213182, - -3547229.913123814, - 6360089.518213182, - ), - crs=CRS.from_epsg(32616), - ) + def test_different_crs_12(self) -> None: + ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087")) + ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4326")) ds = IntersectionDataset(ds1, ds2) - assert len(ds) == 1 - - def test_different_res(self) -> None: - ds1 = CustomGeoDataset(res=1) - ds2 = CustomGeoDataset(res=2) + sample = ds[ds.bounds] + assert ds1.crs == ds2.crs == ds.crs == CRS.from_epsg(4087) + assert ds1.res == ds2.res == ds.res == 2 + assert len(ds1) == len(ds2) == len(ds) == 1 + assert isinstance(sample["image"], torch.Tensor) + + def test_different_crs_12_3(self) -> None: + ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087")) + ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4326")) + ds3 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_32631")) + ds = (ds1 & ds2) & ds3 + sample = ds[ds.bounds] + assert ds1.crs == ds2.crs == ds3.crs == ds.crs == CRS.from_epsg(4087) + assert ds1.res == ds2.res == ds3.res == ds.res == 2 + assert len(ds1) == len(ds2) == len(ds3) == len(ds) == 1 + assert isinstance(sample["image"], torch.Tensor) + + def test_different_crs_1_23(self) -> None: + ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087")) + ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4326")) + ds3 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_32631")) + ds = ds1 & (ds2 & ds3) + sample = ds[ds.bounds] + assert ds1.crs == ds2.crs == ds3.crs == ds.crs == CRS.from_epsg(4087) + assert ds1.res == ds2.res == ds3.res == ds.res == 2 + assert len(ds1) == len(ds2) == len(ds3) == len(ds) == 1 + assert isinstance(sample["image"], torch.Tensor) + + def test_different_res_12(self) -> None: + ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087")) + ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_4_epsg_4087")) ds = IntersectionDataset(ds1, ds2) - assert len(ds) == 1 + sample = ds[ds.bounds] + assert ds1.crs == ds2.crs == ds.crs == CRS.from_epsg(4087) + assert ds1.res == ds2.res == ds.res == 2 + assert len(ds1) == len(ds2) == len(ds) == 1 + assert isinstance(sample["image"], torch.Tensor) + + def test_different_res_12_3(self) -> None: + ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087")) + ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_4_epsg_4087")) + ds3 = RasterDataset(os.path.join("tests", "data", "raster", "res_8_epsg_4087")) + ds = (ds1 & ds2) & ds3 + sample = ds[ds.bounds] + assert ds1.crs == ds2.crs == ds3.crs == ds.crs == CRS.from_epsg(4087) + assert ds1.res == ds2.res == ds3.res == ds.res == 2 + assert len(ds1) == len(ds2) == len(ds3) == len(ds) == 1 + assert isinstance(sample["image"], torch.Tensor) + + def test_different_res_1_23(self) -> None: + ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087")) + ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_4_epsg_4087")) + ds3 = RasterDataset(os.path.join("tests", "data", "raster", "res_8_epsg_4087")) + ds = ds1 & (ds2 & ds3) + sample = ds[ds.bounds] + assert ds1.crs == ds2.crs == ds3.crs == ds.crs == CRS.from_epsg(4087) + assert ds1.res == ds2.res == ds3.res == ds.res == 2 + assert len(ds1) == len(ds2) == len(ds3) == len(ds) == 1 + assert isinstance(sample["image"], torch.Tensor) def test_no_overlap(self) -> None: ds1 = CustomGeoDataset(BoundingBox(0, 1, 2, 3, 4, 5)) @@ -433,7 +474,7 @@ def test_no_overlap(self) -> None: IntersectionDataset(ds1, ds2) def test_invalid_query(self, dataset: IntersectionDataset) -> None: - query = BoundingBox(0, 0, 0, 0, 0, 0) + query = BoundingBox(-1, -1, -1, -1, -1, -1) with pytest.raises( IndexError, match="query: .* not found in index with bounds:" ): @@ -443,14 +484,15 @@ def test_invalid_query(self, dataset: IntersectionDataset) -> None: class TestUnionDataset: @pytest.fixture(scope="class") def dataset(self) -> UnionDataset: - ds1 = CustomGeoDataset(bounds=BoundingBox(0, 1, 0, 1, 0, 1)) - ds2 = CustomGeoDataset(bounds=BoundingBox(2, 3, 2, 3, 2, 3)) + ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087")) + ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_4_epsg_4326")) transforms = nn.Identity() return UnionDataset(ds1, ds2, transforms=transforms) def test_getitem(self, dataset: UnionDataset) -> None: - query = BoundingBox(0, 1, 0, 1, 0, 1) - assert dataset[query] == {"index": query} + query = dataset.bounds + sample = dataset[query] + assert isinstance(sample["image"], torch.Tensor) def test_len(self, dataset: UnionDataset) -> None: assert len(dataset) == 2 @@ -461,6 +503,76 @@ def test_str(self, dataset: UnionDataset) -> None: assert "bbox: BoundingBox" in out assert "size: 2" in out + def test_different_crs_12(self) -> None: + ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087")) + ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4326")) + ds = UnionDataset(ds1, ds2) + sample = ds[ds.bounds] + assert ds1.crs == ds2.crs == ds.crs == CRS.from_epsg(4087) + assert ds1.res == ds2.res == ds.res == 2 + assert len(ds1) == len(ds2) == 1 + assert len(ds) == 2 + assert isinstance(sample["image"], torch.Tensor) + + def test_different_crs_12_3(self) -> None: + ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087")) + ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4326")) + ds3 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_32631")) + ds = (ds1 | ds2) | ds3 + sample = ds[ds.bounds] + assert ds1.crs == ds2.crs == ds3.crs == ds.crs == CRS.from_epsg(4087) + assert ds1.res == ds2.res == ds3.res == ds.res == 2 + assert len(ds1) == len(ds2) == len(ds3) == 1 + assert len(ds) == 3 + assert isinstance(sample["image"], torch.Tensor) + + def test_different_crs_1_23(self) -> None: + ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087")) + ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4326")) + ds3 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_32631")) + ds = ds1 | (ds2 | ds3) + sample = ds[ds.bounds] + assert ds1.crs == ds2.crs == ds3.crs == ds.crs == CRS.from_epsg(4087) + assert ds1.res == ds2.res == ds3.res == ds.res == 2 + assert len(ds1) == len(ds2) == len(ds3) == 1 + assert len(ds) == 3 + assert isinstance(sample["image"], torch.Tensor) + + def test_different_res_12(self) -> None: + ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087")) + ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_4_epsg_4087")) + ds = UnionDataset(ds1, ds2) + sample = ds[ds.bounds] + assert ds1.crs == ds2.crs == ds.crs == CRS.from_epsg(4087) + assert ds1.res == ds2.res == ds.res == 2 + assert len(ds1) == len(ds2) == 1 + assert len(ds) == 2 + assert isinstance(sample["image"], torch.Tensor) + + def test_different_res_12_3(self) -> None: + ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087")) + ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_4_epsg_4087")) + ds3 = RasterDataset(os.path.join("tests", "data", "raster", "res_8_epsg_4087")) + ds = (ds1 | ds2) | ds3 + sample = ds[ds.bounds] + assert ds1.crs == ds2.crs == ds3.crs == ds.crs == CRS.from_epsg(4087) + assert ds1.res == ds2.res == ds3.res == ds.res == 2 + assert len(ds1) == len(ds2) == len(ds3) == 1 + assert len(ds) == 3 + assert isinstance(sample["image"], torch.Tensor) + + def test_different_res_1_23(self) -> None: + ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087")) + ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_4_epsg_4087")) + ds3 = RasterDataset(os.path.join("tests", "data", "raster", "res_8_epsg_4087")) + ds = ds1 | (ds2 | ds3) + sample = ds[ds.bounds] + assert ds1.crs == ds2.crs == ds3.crs == ds.crs == CRS.from_epsg(4087) + assert ds1.res == ds2.res == ds3.res == ds.res == 2 + assert len(ds1) == len(ds2) == len(ds3) == 1 + assert len(ds) == 3 + assert isinstance(sample["image"], torch.Tensor) + def test_nongeo_dataset(self) -> None: ds1 = CustomNonGeoDataset() ds2 = CustomNonGeoDataset() @@ -473,22 +585,8 @@ def test_nongeo_dataset(self) -> None: with pytest.raises(ValueError, match=msg): UnionDataset(ds3, ds1) # type: ignore[arg-type] - def test_different_crs(self) -> None: - ds1 = CustomGeoDataset(crs=CRS.from_epsg(3005)) - ds2 = CustomGeoDataset(crs=CRS.from_epsg(32616)) - ds = UnionDataset(ds1, ds2) - assert ds.crs == ds1.crs - assert len(ds) == 2 - - def test_different_res(self) -> None: - ds1 = CustomGeoDataset(res=1) - ds2 = CustomGeoDataset(res=2) - ds = UnionDataset(ds1, ds2) - assert ds.res == ds1.res - assert len(ds) == 2 - def test_invalid_query(self, dataset: UnionDataset) -> None: - query = BoundingBox(4, 5, 4, 5, 4, 5) + query = BoundingBox(-1, -1, -1, -1, -1, -1) with pytest.raises( IndexError, match="query: .* not found in index with bounds:" ): diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index d4a81f73b46..dc2b5fa1388 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -23,7 +23,6 @@ from rasterio.crs import CRS from rasterio.io import DatasetReader from rasterio.vrt import WarpedVRT -from rasterio.windows import from_bounds from rtree.index import Index, Property from torch import Tensor from torch.utils.data import Dataset @@ -73,9 +72,8 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC): dataset = landsat7 | landsat8 """ - #: Resolution of the dataset in units of CRS. - res: float - _crs: CRS + _crs = CRS.from_epsg(4326) + _res = 0.0 # NOTE: according to the Python docs: # @@ -213,12 +211,10 @@ def bounds(self) -> BoundingBox: @property def crs(self) -> CRS: - """:term:`coordinate reference system (CRS)` for the dataset. + """:term:`coordinate reference system (CRS)` of the dataset. Returns: - the :term:`coordinate reference system (CRS)` - - .. versionadded:: 0.2 + The :term:`coordinate reference system (CRS)`. """ return self._crs @@ -229,17 +225,16 @@ def crs(self, new_crs: CRS) -> None: If ``new_crs == self.crs``, does nothing, otherwise updates the R-tree index. Args: - new_crs: new :term:`coordinate reference system (CRS)` - - .. versionadded:: 0.2 + new_crs: New :term:`coordinate reference system (CRS)`. """ - if new_crs == self._crs: + if new_crs == self.crs: return + print(f"Converting {self.__class__.__name__} CRS from {self.crs} to {new_crs}") new_index = Index(interleaved=False, properties=Property(dimension=3)) project = pyproj.Transformer.from_crs( - pyproj.CRS(str(self._crs)), pyproj.CRS(str(new_crs)), always_xy=True + pyproj.CRS(str(self.crs)), pyproj.CRS(str(new_crs)), always_xy=True ).transform for hit in self.index.intersection(self.index.bounds, objects=True): old_minx, old_maxx, old_miny, old_maxy, mint, maxt = hit.bounds @@ -252,6 +247,28 @@ def crs(self, new_crs: CRS) -> None: self._crs = new_crs self.index = new_index + @property + def res(self) -> float: + """Resolution of the dataset in units of CRS. + + Returns: + The resolution of the dataset. + """ + return self._res + + @res.setter + def res(self, new_res: float) -> None: + """Change the resolution of a GeoDataset. + + Args: + new_res: New resolution. + """ + if new_res == self.res: + return + + print(f"Converting {self.__class__.__name__} res from {self.res} to {new_res}") + self._res = new_res + class RasterDataset(GeoDataset): """Abstract base class for :class:`GeoDataset` stored as raster files.""" @@ -399,7 +416,7 @@ def __init__( raise AssertionError(msg) self._crs = cast(CRS, crs) - self.res = cast(float, res) + self._res = cast(float, res) def __getitem__(self, query: BoundingBox) -> dict[str, Any]: """Retrieve image/mask and metadata indexed by query. @@ -477,22 +494,7 @@ def _merge_files( vrt_fhs = [self._load_warp_file(fp) for fp in filepaths] bounds = (query.minx, query.miny, query.maxx, query.maxy) - if len(vrt_fhs) == 1: - src = vrt_fhs[0] - out_width = round((query.maxx - query.minx) / self.res) - out_height = round((query.maxy - query.miny) / self.res) - count = len(band_indexes) if band_indexes else src.count - out_shape = (count, out_height, out_width) - dest = src.read( - indexes=band_indexes, - out_shape=out_shape, - window=from_bounds(*bounds, src.transform), - boundless=True, - ) - else: - dest, _ = rasterio.merge.merge( - vrt_fhs, bounds, self.res, indexes=band_indexes - ) + dest, _ = rasterio.merge.merge(vrt_fhs, bounds, self.res, indexes=band_indexes) # fix numpy dtypes which are not supported by pytorch tensors if dest.dtype == np.uint16: @@ -574,7 +576,6 @@ def __init__( super().__init__(transforms) self.root = root - self.res = res self.label_name = label_name # Populate the dataset index @@ -605,6 +606,7 @@ def __init__( raise FileNotFoundError(msg) self._crs = crs + self._res = res def __getitem__(self, query: BoundingBox) -> dict[str, Any]: """Retrieve image/mask and metadata indexed by query. @@ -844,23 +846,9 @@ def __init__( if not isinstance(ds, GeoDataset): raise ValueError("IntersectionDataset only supports GeoDatasets") - self._crs = dataset1.crs + self.crs = dataset1.crs self.res = dataset1.res - # Force dataset2 to have the same CRS/res as dataset1 - if dataset1.crs != dataset2.crs: - print( - f"Converting {dataset2.__class__.__name__} CRS from " - f"{dataset2.crs} to {dataset1.crs}" - ) - dataset2.crs = dataset1.crs - if dataset1.res != dataset2.res: - print( - f"Converting {dataset2.__class__.__name__} resolution from " - f"{dataset2.res} to {dataset1.res}" - ) - dataset2.res = dataset1.res - # Merge dataset indices into a single index self._merge_dataset_indices() @@ -917,6 +905,46 @@ def __str__(self) -> str: bbox: {self.bounds} size: {len(self)}""" + @property + def crs(self) -> CRS: + """:term:`coordinate reference system (CRS)` of both datasets. + + Returns: + The :term:`coordinate reference system (CRS)`. + """ + return self._crs + + @crs.setter + def crs(self, new_crs: CRS) -> None: + """Change the :term:`coordinate reference system (CRS)` of both datasets. + + Args: + new_crs: New :term:`coordinate reference system (CRS)`. + """ + self._crs = new_crs + self.datasets[0].crs = new_crs + self.datasets[1].crs = new_crs + + @property + def res(self) -> float: + """Resolution of both datasets in units of CRS. + + Returns: + Resolution of both datasets. + """ + return self._res + + @res.setter + def res(self, new_res: float) -> None: + """Change the resolution of both datasets. + + Args: + new_res: New resolution. + """ + self._res = new_res + self.datasets[0].res = new_res + self.datasets[1].res = new_res + class UnionDataset(GeoDataset): """Dataset representing the union of two GeoDatasets. @@ -970,23 +998,9 @@ def __init__( if not isinstance(ds, GeoDataset): raise ValueError("UnionDataset only supports GeoDatasets") - self._crs = dataset1.crs + self.crs = dataset1.crs self.res = dataset1.res - # Force dataset2 to have the same CRS/res as dataset1 - if dataset1.crs != dataset2.crs: - print( - f"Converting {dataset2.__class__.__name__} CRS from " - f"{dataset2.crs} to {dataset1.crs}" - ) - dataset2.crs = dataset1.crs - if dataset1.res != dataset2.res: - print( - f"Converting {dataset2.__class__.__name__} resolution from " - f"{dataset2.res} to {dataset1.res}" - ) - dataset2.res = dataset1.res - # Merge dataset indices into a single index self._merge_dataset_indices() @@ -1040,3 +1054,43 @@ def __str__(self) -> str: type: UnionDataset bbox: {self.bounds} size: {len(self)}""" + + @property + def crs(self) -> CRS: + """:term:`coordinate reference system (CRS)` of both datasets. + + Returns: + The :term:`coordinate reference system (CRS)`. + """ + return self._crs + + @crs.setter + def crs(self, new_crs: CRS) -> None: + """Change the :term:`coordinate reference system (CRS)` of both datasets. + + Args: + new_crs: New :term:`coordinate reference system (CRS)`. + """ + self._crs = new_crs + self.datasets[0].crs = new_crs + self.datasets[1].crs = new_crs + + @property + def res(self) -> float: + """Resolution of both datasets in units of CRS. + + Returns: + The resolution of both datasets. + """ + return self._res + + @res.setter + def res(self, new_res: float) -> None: + """Change the resolution of both datasets. + + Args: + new_res: New resolution. + """ + self._res = new_res + self.datasets[0].res = new_res + self.datasets[1].res = new_res