From bbf6108516ea241d8b260c4bf9d18599f822c2ca Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Thu, 18 May 2023 23:33:47 -0500 Subject: [PATCH] Fix reprojection issues (#1344) * Add tests to catch reprojection issues * More specific tests * Fix boundless reprojection * black/flake8 * Add getters and setters for all crs/res attributes * Actually fix reprojection bug * Fix tests * Get all tests passing * More specific tests * Remove aux files * Ignore *.aux.xml files * Fix mypy * Increase coverage * Use dtype properly * Increase coverage * Add newline --- .gitignore | 1 + tests/data/raster/data.py | 102 +++++++-- .../res_2_epsg_32631/res_2_epsg_32631.tif | Bin 0 -> 436 bytes .../res_2_epsg_4087/res_2_epsg_4087.tif | Bin 0 -> 453 bytes .../res_2_epsg_4326/res_2_epsg_4326.tif | Bin 0 -> 442 bytes .../res_4_epsg_32631/res_4_epsg_32631.tif | Bin 0 -> 388 bytes .../res_4_epsg_4087/res_4_epsg_4087.tif | Bin 0 -> 405 bytes .../res_4_epsg_4326/res_4_epsg_4326.tif | Bin 0 -> 394 bytes .../res_8_epsg_32631/res_8_epsg_32631.tif | Bin 0 -> 376 bytes .../res_8_epsg_4087/res_8_epsg_4087.tif | Bin 0 -> 393 bytes .../res_8_epsg_4326/res_8_epsg_4326.tif | Bin 0 -> 382 bytes tests/data/raster/test0.tif | Bin 421 -> 0 bytes tests/data/raster/uint16/corrupted.tif | 1 + tests/data/raster/uint16/uint16.tif | Bin 0 -> 517 bytes tests/data/raster/uint32/corrupted.tif | 1 + tests/data/raster/uint32/uint32.tif | Bin 0 -> 645 bytes tests/datasets/test_geo.py | 204 +++++++++++++----- torchgeo/datasets/geo.py | 176 +++++++++------ 18 files changed, 349 insertions(+), 136 deletions(-) create mode 100644 tests/data/raster/res_2_epsg_32631/res_2_epsg_32631.tif create mode 100644 tests/data/raster/res_2_epsg_4087/res_2_epsg_4087.tif create mode 100644 tests/data/raster/res_2_epsg_4326/res_2_epsg_4326.tif create mode 100644 tests/data/raster/res_4_epsg_32631/res_4_epsg_32631.tif create mode 100644 tests/data/raster/res_4_epsg_4087/res_4_epsg_4087.tif create mode 100644 tests/data/raster/res_4_epsg_4326/res_4_epsg_4326.tif create mode 100644 tests/data/raster/res_8_epsg_32631/res_8_epsg_32631.tif create mode 100644 tests/data/raster/res_8_epsg_4087/res_8_epsg_4087.tif create mode 100644 tests/data/raster/res_8_epsg_4326/res_8_epsg_4326.tif delete mode 100644 tests/data/raster/test0.tif create mode 100644 tests/data/raster/uint16/corrupted.tif create mode 100644 tests/data/raster/uint16/uint16.tif create mode 100644 tests/data/raster/uint32/corrupted.tif create mode 100644 tests/data/raster/uint32/uint32.tif diff --git a/.gitignore b/.gitignore index d490e65b1da..29e2b022a4a 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ /output/ *.pdf /results/ +*.aux.xml # Spack .spack-env/ diff --git a/tests/data/raster/data.py b/tests/data/raster/data.py index cc00aafef9b..517649607ba 100755 --- a/tests/data/raster/data.py +++ b/tests/data/raster/data.py @@ -2,41 +2,99 @@ # Licensed under the MIT License. import os +from typing import Optional import numpy as np -import rasterio -import rasterio.transform -from torchvision.datasets.utils import calculate_md5 +import rasterio as rio +from rasterio.transform import from_bounds +from rasterio.warp import calculate_default_transform, reproject +RES = [2, 4, 8] +EPSG = [4087, 4326, 32631] +SIZE = 16 -def generate_test_data(fn: str) -> str: - """Creates test data with uint32 datatype. - Args: - fn (str): Filename to write +def write_raster( + res: int = RES[0], + epsg: int = EPSG[0], + dtype: str = "uint8", + path: Optional[str] = None, +) -> None: + """Write a raster file. - Returns: - str: md5 hash of created archive + Args: + res: Resolution. + epsg: EPSG of file. + dtype: Data type. + path: File path. """ + size = SIZE // res profile = { "driver": "GTiff", - "dtype": "uint32", + "dtype": dtype, "count": 1, - "crs": "epsg:4326", - "transform": rasterio.transform.from_bounds(0, 0, 1, 1, 1, 1), - "height": 4, - "width": 4, - "compress": "lzw", - "predictor": 2, + "crs": f"epsg:{epsg}", + "transform": from_bounds(0, 0, SIZE, SIZE, size, size), + "height": size, + "width": size, + "nodata": 0, } - with rasterio.open(fn, "w", **profile) as f: - f.write(np.random.randint(0, 256, size=(1, 4, 4))) + if path is None: + name = f"res_{res}_epsg_{epsg}" + path = os.path.join(name, f"{name}.tif") + + directory = os.path.dirname(path) + os.makedirs(directory, exist_ok=True) + + with rio.open(path, "w", **profile) as f: + x = np.ones((1, size, size)) + f.write(x) + - md5: str = calculate_md5(fn) - return md5 +def reproject_raster(res: int, src_epsg: int, dst_epsg: int) -> None: + """Reproject a raster file. + + Args: + res: Resolution. + src_epsg: EPSG of source file. + dst_epsg: EPSG of destination file. + """ + src_name = f"res_{res}_epsg_{src_epsg}" + src_path = os.path.join(src_name, f"{src_name}.tif") + with rio.open(src_path) as src: + dst_crs = f"epsg:{dst_epsg}" + transform, width, height = calculate_default_transform( + src.crs, dst_crs, src.width, src.height, *src.bounds + ) + profile = src.profile.copy() + profile.update( + {"crs": dst_crs, "transform": transform, "width": width, "height": height} + ) + dst_name = f"res_{res}_epsg_{dst_epsg}" + os.makedirs(dst_name, exist_ok=True) + dst_path = os.path.join(dst_name, f"{dst_name}.tif") + with rio.open(dst_path, "w", **profile) as dst: + reproject( + source=rio.band(src, 1), + destination=rio.band(dst, 1), + src_transform=src.transform, + src_crs=src.crs, + dst_transform=dst.transform, + dst_crs=dst.crs, + ) if __name__ == "__main__": - md5_hash = generate_test_data(os.path.join(os.getcwd(), "test0.tif")) - print(md5_hash) + for res in RES: + src_epsg = EPSG[0] + write_raster(res, src_epsg) + + for dst_epsg in EPSG[1:]: + reproject_raster(res, src_epsg, dst_epsg) + + for dtype in ["uint16", "uint32"]: + path = os.path.join(dtype, f"{dtype}.tif") + write_raster(dtype=dtype, path=path) + with open(os.path.join(dtype, "corrupted.tif"), "w") as f: + f.write("not a tif file\n") diff --git a/tests/data/raster/res_2_epsg_32631/res_2_epsg_32631.tif b/tests/data/raster/res_2_epsg_32631/res_2_epsg_32631.tif new file mode 100644 index 0000000000000000000000000000000000000000..caa023eed76a12d72d728942b4af8c71ba71e635 GIT binary patch literal 436 zcmebD)MDUZU|=Ur9FQFbWH&Bh zVgQ+80K_2Bq}{n>(VPGFFd8C?o2)vY+pX=w;wa7A-L7&`-vOqQ5$qdw1_l;jSb!ju zW@6dcE(TP@$g#1Vok5I&4H#IA41LNBJRrKfo&`*6Du=rVD_EE)=qrSV_$pN8=cOtb P8~W8iL~DSdOBnzF+=nKo literal 0 HcmV?d00001 diff --git a/tests/data/raster/res_2_epsg_4087/res_2_epsg_4087.tif b/tests/data/raster/res_2_epsg_4087/res_2_epsg_4087.tif new file mode 100644 index 0000000000000000000000000000000000000000..752611e87bb1f61384441ffaf679e3a3c149ab1d GIT binary patch literal 453 zcmebD)MDUZU|46^ELA@Fazb!0dZ3c4+9&Jeg?$r+nE^@fV3FU{Eh8Q4EjKJ7?9n#go%L( z$Tk3CB;WvL!)PQad`ttldPcBs*cljDfMEfGP@0KlW4k)g2u6;L?d%Nd3~a!_Vr1x3 zX5azQ-}zand$59qiGsdDcz#h%ih^rlX=X}haY;f`v&9 IFpMb!0IudEnE(I) literal 0 HcmV?d00001 diff --git a/tests/data/raster/res_4_epsg_32631/res_4_epsg_32631.tif b/tests/data/raster/res_4_epsg_32631/res_4_epsg_32631.tif new file mode 100644 index 0000000000000000000000000000000000000000..dd69237d92b1c8771afd17d2051fa315df42c4f6 GIT binary patch literal 388 zcmebD)MDUZU|-qGR53%@Aa!g=Y(YjAu-+1&gea1@7?ce% zQyi)WXcL1BRGbOO4n|VX*UZDf3=}&D#7!+c3~WI984#~;XJ$|U(qcgKH?}h|$N|}5 zKz8F2CI*lh20#n~P1>DH7QNwjfYA_9++@}H+-_|b7Ds8`?sk=n`VKIaj9}leg98J^ z209%?GBB}hY!?G6V&vG^&dwmlzy=HyMut9R1|ASyUe5xiHI>8NgB2`H6!aBBLwpsg T^7B#^j1B#2Afh$E&_x9R9K|GE literal 0 HcmV?d00001 diff --git a/tests/data/raster/res_4_epsg_4087/res_4_epsg_4087.tif b/tests/data/raster/res_4_epsg_4087/res_4_epsg_4087.tif new file mode 100644 index 0000000000000000000000000000000000000000..c9ce35ee860058fd9c80530f9909a9bba4c39fb1 GIT binary patch literal 405 zcmebD)MDUZU|-qGR53%@Aa!g=Y(YjAu-;angea1@7?ce% zQyi)WXcL1BlA2&7HeWLj12a(m91u6P@G!6e>1RN^zMYvt0Z5Ah&EMG0#GnskhXL7* zOPCm#fNTRGMgRc^2pdKtWbm^Lp!$IR2m6K{92g)r(CJ_?CYFuu>OeynIX1SlGpIAL z0Rx4Rp--8C2Sk78X93fi%Hi(83Kk{``U>IsML8)7u7#zUDVfD3iFqXo&XqZtc_~Gi P$%#2N5M?#M;6?=ioUtHW literal 0 HcmV?d00001 diff --git a/tests/data/raster/res_4_epsg_4326/res_4_epsg_4326.tif b/tests/data/raster/res_4_epsg_4326/res_4_epsg_4326.tif new file mode 100644 index 0000000000000000000000000000000000000000..eb2d7a9c66d6f2876bf9b9007ca06543cca61e99 GIT binary patch literal 394 zcmebD)MDUZU|-qGR53%@Aa!g=Y(YjAu-+=5gea1@7?ce% zQyi*B0LYesiZcP(!AR=)nt2$Qfnv9SxT%GQfelE10^;@U%nS-ZS_x?WhISqXkiH}! zdt*Bj*!4v~cH z1ZW{JKo}XG2{3YOYzN9Ruz~q~%A6b<+ClQ%V48uUBe!Il?2;k}pc02=ArX$@?!gKc KCN;n?Mg;(ikReY1 literal 0 HcmV?d00001 diff --git a/tests/data/raster/res_8_epsg_32631/res_8_epsg_32631.tif b/tests/data/raster/res_8_epsg_32631/res_8_epsg_32631.tif new file mode 100644 index 0000000000000000000000000000000000000000..92d838ab2c2e91205ebbf1fc05c10ef8a4f439af GIT binary patch literal 376 zcmebD)MDUZU|-qGR53%@Aa!g=Y(YjAu-+1&gea1@7?ce% zQyi*>1;~~`5)VdF&)3Yuzzmc>2gFS+JPd3=`WX>&87y}zHP#78dlo@zHba_1snATJdcMn#uFj3G~2o3R7sLIbv RRWLU6tAU8t07I9N5ddhRBwqjk literal 0 HcmV?d00001 diff --git a/tests/data/raster/res_8_epsg_4087/res_8_epsg_4087.tif b/tests/data/raster/res_8_epsg_4087/res_8_epsg_4087.tif new file mode 100644 index 0000000000000000000000000000000000000000..0b989058f165a8a61efe7162b9d2d4536724942c GIT binary patch literal 393 zcmebD)MDUZU|-qGR53%@Aa!g=Y(YjAu-;angea1@7?ce% zQyi*>1;~~`QWK2C=4<9*UOeynIX1SlGpIAL z0Rx4Rp--8C2Sk78X93fi%Hi(83Kk{``U>IsML8)7u7#zUDVfD3iFqXo&XqZtc_~Gi Q$%#2N5M?#M;AUh502!7bUH||9 literal 0 HcmV?d00001 diff --git a/tests/data/raster/res_8_epsg_4326/res_8_epsg_4326.tif b/tests/data/raster/res_8_epsg_4326/res_8_epsg_4326.tif new file mode 100644 index 0000000000000000000000000000000000000000..aa0a7318bcc97fcfa1ef9a3f7387ad45f288550c GIT binary patch literal 382 zcmebD)MDUZU|-qGR53%@Aa!g=Y(YjAu-+=5gea1@7?ce% zQyi*>1;~~`5)VdF&)3YuzzmeX1;kA)JPd3=`V$bZZ)avu0Mbf8^Eb5ffZdn`WN&O| z0=vEl$ZlN11Qs^{ssjPp0ttoF*M#h0G(?mzDF@c4WDirx2=*B}0|U_Ej6fTJ-T=}- z3xNT`$nZ>nkz->!P?muW%-pD0L7W1Y*rwf4ax@T5oBZm>#YKEM3KbB zplpzt;!ri(K(-8$8e1rv2`C8C9uWx5&Pyo_OK)W`y z^Duz)B>~wR+nK;VDFU();DbGc4Wkh<_*n?`j9{O!g9C;UXdBQQKoOuNz+hu!cqYKe zv9TQ}%fJTa_bGF7Y-k6`bAxFHhK}5lX|hX-9Dqt3mW4z(hPwwVSeVo>H1ISCnC#(U psIEWceIZpxaxG{5tOZ9qnm+MZHrpg|?r1FH320VW#F}Bk0svU*EC2ui diff --git a/tests/data/raster/uint16/corrupted.tif b/tests/data/raster/uint16/corrupted.tif new file mode 100644 index 00000000000..42e548ffea8 --- /dev/null +++ b/tests/data/raster/uint16/corrupted.tif @@ -0,0 +1 @@ +not a tif file diff --git a/tests/data/raster/uint16/uint16.tif b/tests/data/raster/uint16/uint16.tif new file mode 100644 index 0000000000000000000000000000000000000000..05e38bdc42ea6c15aba27fa1703e8102cb35580d GIT binary patch literal 517 zcmebD)MDUZU|~1fzko3ELNff*=&4v3ptco^7#^fMq{-_FdS0Hno$=5K6gV$cV&!+`9@ zB}@!VK(+x8BLN2}8%85Z;bR)W)iZ*9!_L6K0t^cfgwjkb8{5@^Mlf=0Y-eXsXJ7+* zjFF*FnSlpHf9Gcb)0)cR?!gKcCJOor;rT^5DGIKIrI{(2#U+V(B?``!IhlDWMVZNo OIW-VvH4rNXkOTmI_8{;8 literal 0 HcmV?d00001 diff --git a/tests/data/raster/uint32/corrupted.tif b/tests/data/raster/uint32/corrupted.tif new file mode 100644 index 00000000000..42e548ffea8 --- /dev/null +++ b/tests/data/raster/uint32/corrupted.tif @@ -0,0 +1 @@ +not a tif file diff --git a/tests/data/raster/uint32/uint32.tif b/tests/data/raster/uint32/uint32.tif new file mode 100644 index 0000000000000000000000000000000000000000..e42c41fc4706ce90c287eb460fa2280ab9cf31d8 GIT binary patch literal 645 zcmebD)MDUZU|8IVYy%)h0uE3%j7E~e$25SeX9W9(oq>S`7#1K1rI}bZwyOh;VC2}?&d#9Dzy|af zBSW7u0}qJ)&d&m-HI>8NgB2`H6!aCs^NVs)6kH2SGgC5)OA_-+6r3w_GV@Z3GLsW? PY9PvLfLcM?M=k~c1!y3F literal 0 HcmV?d00001 diff --git a/tests/datasets/test_geo.py b/tests/datasets/test_geo.py index 259c583ebe2..cf4ef25d880 100644 --- a/tests/datasets/test_geo.py +++ b/tests/datasets/test_geo.py @@ -30,7 +30,7 @@ class CustomGeoDataset(GeoDataset): def __init__( self, bounds: BoundingBox = BoundingBox(0, 1, 2, 3, 4, 5), - crs: CRS = CRS.from_epsg(3005), + crs: CRS = CRS.from_epsg(4087), res: float = 1, ) -> None: super().__init__() @@ -74,7 +74,7 @@ def test_getitem(self, dataset: GeoDataset) -> None: def test_len(self, dataset: GeoDataset) -> None: assert len(dataset) == 1 - @pytest.mark.parametrize("crs", [CRS.from_epsg(3005), CRS.from_epsg(32616)]) + @pytest.mark.parametrize("crs", [CRS.from_epsg(4087), CRS.from_epsg(32631)]) def test_crs(self, dataset: GeoDataset, crs: CRS) -> None: dataset.crs = crs @@ -157,7 +157,7 @@ class TestRasterDataset: def naip(self, request: SubRequest) -> NAIP: root = os.path.join("tests", "data", "naip") bands = request.param[0] - crs = CRS.from_epsg(3005) + crs = CRS.from_epsg(4087) transforms = nn.Identity() cache = request.param[1] return NAIP(root, crs=crs, bands=bands, transforms=transforms, cache=cache) @@ -178,11 +178,6 @@ def sentinel(self, request: SubRequest) -> Sentinel2: cache = request.param[1] return Sentinel2(root, bands=bands, transforms=transforms, cache=cache) - @pytest.fixture() - def custom_dtype_ds(self) -> RasterDataset: - root = os.path.join("tests", "data", "raster") - return RasterDataset(root) - def test_getitem_single_file(self, naip: NAIP) -> None: x = naip[naip.bounds] assert isinstance(x, dict) @@ -197,8 +192,11 @@ def test_getitem_separate_files(self, sentinel: Sentinel2) -> None: assert isinstance(x["image"], torch.Tensor) assert len(sentinel.bands) == x["image"].shape[0] - def test_getitem_uint_dtype(self, custom_dtype_ds: RasterDataset) -> None: - x = custom_dtype_ds[custom_dtype_ds.bounds] + @pytest.mark.parametrize("dtype", ["uint16", "uint32"]) + def test_getitem_uint_dtype(self, dtype: str) -> None: + root = os.path.join("tests", "data", "raster", dtype) + ds = RasterDataset(root) + x = ds[ds.bounds] assert isinstance(x, dict) assert isinstance(x["image"], torch.Tensor) assert x["image"].dtype == torch.float32 @@ -377,14 +375,15 @@ def test_str(self, dataset: NonGeoClassificationDataset) -> None: class TestIntersectionDataset: @pytest.fixture(scope="class") def dataset(self) -> IntersectionDataset: - ds1 = CustomGeoDataset() - ds2 = CustomGeoDataset() + ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087")) + ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_4_epsg_4326")) transforms = nn.Identity() return IntersectionDataset(ds1, ds2, transforms=transforms) def test_getitem(self, dataset: IntersectionDataset) -> None: - query = BoundingBox(0, 1, 2, 3, 4, 5) - assert dataset[query] == {"index": query} + query = dataset.bounds + sample = dataset[query] + assert isinstance(sample["image"], torch.Tensor) def test_len(self, dataset: IntersectionDataset) -> None: assert len(dataset) == 1 @@ -403,27 +402,69 @@ def test_nongeo_dataset(self) -> None: ): IntersectionDataset(ds1, ds2) # type: ignore[arg-type] - def test_different_crs(self) -> None: - ds1 = CustomGeoDataset(BoundingBox(0, 1, 0, 1, 0, 1), crs=CRS.from_epsg(3005)) - ds2 = CustomGeoDataset( - BoundingBox( - -3547229.913123814, - 6360089.518213182, - -3547229.913123814, - 6360089.518213182, - -3547229.913123814, - 6360089.518213182, - ), - crs=CRS.from_epsg(32616), - ) + def test_different_crs_12(self) -> None: + ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087")) + ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4326")) ds = IntersectionDataset(ds1, ds2) - assert len(ds) == 1 - - def test_different_res(self) -> None: - ds1 = CustomGeoDataset(res=1) - ds2 = CustomGeoDataset(res=2) + sample = ds[ds.bounds] + assert ds1.crs == ds2.crs == ds.crs == CRS.from_epsg(4087) + assert ds1.res == ds2.res == ds.res == 2 + assert len(ds1) == len(ds2) == len(ds) == 1 + assert isinstance(sample["image"], torch.Tensor) + + def test_different_crs_12_3(self) -> None: + ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087")) + ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4326")) + ds3 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_32631")) + ds = (ds1 & ds2) & ds3 + sample = ds[ds.bounds] + assert ds1.crs == ds2.crs == ds3.crs == ds.crs == CRS.from_epsg(4087) + assert ds1.res == ds2.res == ds3.res == ds.res == 2 + assert len(ds1) == len(ds2) == len(ds3) == len(ds) == 1 + assert isinstance(sample["image"], torch.Tensor) + + def test_different_crs_1_23(self) -> None: + ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087")) + ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4326")) + ds3 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_32631")) + ds = ds1 & (ds2 & ds3) + sample = ds[ds.bounds] + assert ds1.crs == ds2.crs == ds3.crs == ds.crs == CRS.from_epsg(4087) + assert ds1.res == ds2.res == ds3.res == ds.res == 2 + assert len(ds1) == len(ds2) == len(ds3) == len(ds) == 1 + assert isinstance(sample["image"], torch.Tensor) + + def test_different_res_12(self) -> None: + ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087")) + ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_4_epsg_4087")) ds = IntersectionDataset(ds1, ds2) - assert len(ds) == 1 + sample = ds[ds.bounds] + assert ds1.crs == ds2.crs == ds.crs == CRS.from_epsg(4087) + assert ds1.res == ds2.res == ds.res == 2 + assert len(ds1) == len(ds2) == len(ds) == 1 + assert isinstance(sample["image"], torch.Tensor) + + def test_different_res_12_3(self) -> None: + ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087")) + ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_4_epsg_4087")) + ds3 = RasterDataset(os.path.join("tests", "data", "raster", "res_8_epsg_4087")) + ds = (ds1 & ds2) & ds3 + sample = ds[ds.bounds] + assert ds1.crs == ds2.crs == ds3.crs == ds.crs == CRS.from_epsg(4087) + assert ds1.res == ds2.res == ds3.res == ds.res == 2 + assert len(ds1) == len(ds2) == len(ds3) == len(ds) == 1 + assert isinstance(sample["image"], torch.Tensor) + + def test_different_res_1_23(self) -> None: + ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087")) + ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_4_epsg_4087")) + ds3 = RasterDataset(os.path.join("tests", "data", "raster", "res_8_epsg_4087")) + ds = ds1 & (ds2 & ds3) + sample = ds[ds.bounds] + assert ds1.crs == ds2.crs == ds3.crs == ds.crs == CRS.from_epsg(4087) + assert ds1.res == ds2.res == ds3.res == ds.res == 2 + assert len(ds1) == len(ds2) == len(ds3) == len(ds) == 1 + assert isinstance(sample["image"], torch.Tensor) def test_no_overlap(self) -> None: ds1 = CustomGeoDataset(BoundingBox(0, 1, 2, 3, 4, 5)) @@ -433,7 +474,7 @@ def test_no_overlap(self) -> None: IntersectionDataset(ds1, ds2) def test_invalid_query(self, dataset: IntersectionDataset) -> None: - query = BoundingBox(0, 0, 0, 0, 0, 0) + query = BoundingBox(-1, -1, -1, -1, -1, -1) with pytest.raises( IndexError, match="query: .* not found in index with bounds:" ): @@ -443,14 +484,15 @@ def test_invalid_query(self, dataset: IntersectionDataset) -> None: class TestUnionDataset: @pytest.fixture(scope="class") def dataset(self) -> UnionDataset: - ds1 = CustomGeoDataset(bounds=BoundingBox(0, 1, 0, 1, 0, 1)) - ds2 = CustomGeoDataset(bounds=BoundingBox(2, 3, 2, 3, 2, 3)) + ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087")) + ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_4_epsg_4326")) transforms = nn.Identity() return UnionDataset(ds1, ds2, transforms=transforms) def test_getitem(self, dataset: UnionDataset) -> None: - query = BoundingBox(0, 1, 0, 1, 0, 1) - assert dataset[query] == {"index": query} + query = dataset.bounds + sample = dataset[query] + assert isinstance(sample["image"], torch.Tensor) def test_len(self, dataset: UnionDataset) -> None: assert len(dataset) == 2 @@ -461,6 +503,76 @@ def test_str(self, dataset: UnionDataset) -> None: assert "bbox: BoundingBox" in out assert "size: 2" in out + def test_different_crs_12(self) -> None: + ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087")) + ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4326")) + ds = UnionDataset(ds1, ds2) + sample = ds[ds.bounds] + assert ds1.crs == ds2.crs == ds.crs == CRS.from_epsg(4087) + assert ds1.res == ds2.res == ds.res == 2 + assert len(ds1) == len(ds2) == 1 + assert len(ds) == 2 + assert isinstance(sample["image"], torch.Tensor) + + def test_different_crs_12_3(self) -> None: + ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087")) + ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4326")) + ds3 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_32631")) + ds = (ds1 | ds2) | ds3 + sample = ds[ds.bounds] + assert ds1.crs == ds2.crs == ds3.crs == ds.crs == CRS.from_epsg(4087) + assert ds1.res == ds2.res == ds3.res == ds.res == 2 + assert len(ds1) == len(ds2) == len(ds3) == 1 + assert len(ds) == 3 + assert isinstance(sample["image"], torch.Tensor) + + def test_different_crs_1_23(self) -> None: + ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087")) + ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4326")) + ds3 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_32631")) + ds = ds1 | (ds2 | ds3) + sample = ds[ds.bounds] + assert ds1.crs == ds2.crs == ds3.crs == ds.crs == CRS.from_epsg(4087) + assert ds1.res == ds2.res == ds3.res == ds.res == 2 + assert len(ds1) == len(ds2) == len(ds3) == 1 + assert len(ds) == 3 + assert isinstance(sample["image"], torch.Tensor) + + def test_different_res_12(self) -> None: + ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087")) + ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_4_epsg_4087")) + ds = UnionDataset(ds1, ds2) + sample = ds[ds.bounds] + assert ds1.crs == ds2.crs == ds.crs == CRS.from_epsg(4087) + assert ds1.res == ds2.res == ds.res == 2 + assert len(ds1) == len(ds2) == 1 + assert len(ds) == 2 + assert isinstance(sample["image"], torch.Tensor) + + def test_different_res_12_3(self) -> None: + ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087")) + ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_4_epsg_4087")) + ds3 = RasterDataset(os.path.join("tests", "data", "raster", "res_8_epsg_4087")) + ds = (ds1 | ds2) | ds3 + sample = ds[ds.bounds] + assert ds1.crs == ds2.crs == ds3.crs == ds.crs == CRS.from_epsg(4087) + assert ds1.res == ds2.res == ds3.res == ds.res == 2 + assert len(ds1) == len(ds2) == len(ds3) == 1 + assert len(ds) == 3 + assert isinstance(sample["image"], torch.Tensor) + + def test_different_res_1_23(self) -> None: + ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087")) + ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_4_epsg_4087")) + ds3 = RasterDataset(os.path.join("tests", "data", "raster", "res_8_epsg_4087")) + ds = ds1 | (ds2 | ds3) + sample = ds[ds.bounds] + assert ds1.crs == ds2.crs == ds3.crs == ds.crs == CRS.from_epsg(4087) + assert ds1.res == ds2.res == ds3.res == ds.res == 2 + assert len(ds1) == len(ds2) == len(ds3) == 1 + assert len(ds) == 3 + assert isinstance(sample["image"], torch.Tensor) + def test_nongeo_dataset(self) -> None: ds1 = CustomNonGeoDataset() ds2 = CustomNonGeoDataset() @@ -473,22 +585,8 @@ def test_nongeo_dataset(self) -> None: with pytest.raises(ValueError, match=msg): UnionDataset(ds3, ds1) # type: ignore[arg-type] - def test_different_crs(self) -> None: - ds1 = CustomGeoDataset(crs=CRS.from_epsg(3005)) - ds2 = CustomGeoDataset(crs=CRS.from_epsg(32616)) - ds = UnionDataset(ds1, ds2) - assert ds.crs == ds1.crs - assert len(ds) == 2 - - def test_different_res(self) -> None: - ds1 = CustomGeoDataset(res=1) - ds2 = CustomGeoDataset(res=2) - ds = UnionDataset(ds1, ds2) - assert ds.res == ds1.res - assert len(ds) == 2 - def test_invalid_query(self, dataset: UnionDataset) -> None: - query = BoundingBox(4, 5, 4, 5, 4, 5) + query = BoundingBox(-1, -1, -1, -1, -1, -1) with pytest.raises( IndexError, match="query: .* not found in index with bounds:" ): diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index d4a81f73b46..dc2b5fa1388 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -23,7 +23,6 @@ from rasterio.crs import CRS from rasterio.io import DatasetReader from rasterio.vrt import WarpedVRT -from rasterio.windows import from_bounds from rtree.index import Index, Property from torch import Tensor from torch.utils.data import Dataset @@ -73,9 +72,8 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC): dataset = landsat7 | landsat8 """ - #: Resolution of the dataset in units of CRS. - res: float - _crs: CRS + _crs = CRS.from_epsg(4326) + _res = 0.0 # NOTE: according to the Python docs: # @@ -213,12 +211,10 @@ def bounds(self) -> BoundingBox: @property def crs(self) -> CRS: - """:term:`coordinate reference system (CRS)` for the dataset. + """:term:`coordinate reference system (CRS)` of the dataset. Returns: - the :term:`coordinate reference system (CRS)` - - .. versionadded:: 0.2 + The :term:`coordinate reference system (CRS)`. """ return self._crs @@ -229,17 +225,16 @@ def crs(self, new_crs: CRS) -> None: If ``new_crs == self.crs``, does nothing, otherwise updates the R-tree index. Args: - new_crs: new :term:`coordinate reference system (CRS)` - - .. versionadded:: 0.2 + new_crs: New :term:`coordinate reference system (CRS)`. """ - if new_crs == self._crs: + if new_crs == self.crs: return + print(f"Converting {self.__class__.__name__} CRS from {self.crs} to {new_crs}") new_index = Index(interleaved=False, properties=Property(dimension=3)) project = pyproj.Transformer.from_crs( - pyproj.CRS(str(self._crs)), pyproj.CRS(str(new_crs)), always_xy=True + pyproj.CRS(str(self.crs)), pyproj.CRS(str(new_crs)), always_xy=True ).transform for hit in self.index.intersection(self.index.bounds, objects=True): old_minx, old_maxx, old_miny, old_maxy, mint, maxt = hit.bounds @@ -252,6 +247,28 @@ def crs(self, new_crs: CRS) -> None: self._crs = new_crs self.index = new_index + @property + def res(self) -> float: + """Resolution of the dataset in units of CRS. + + Returns: + The resolution of the dataset. + """ + return self._res + + @res.setter + def res(self, new_res: float) -> None: + """Change the resolution of a GeoDataset. + + Args: + new_res: New resolution. + """ + if new_res == self.res: + return + + print(f"Converting {self.__class__.__name__} res from {self.res} to {new_res}") + self._res = new_res + class RasterDataset(GeoDataset): """Abstract base class for :class:`GeoDataset` stored as raster files.""" @@ -399,7 +416,7 @@ def __init__( raise AssertionError(msg) self._crs = cast(CRS, crs) - self.res = cast(float, res) + self._res = cast(float, res) def __getitem__(self, query: BoundingBox) -> dict[str, Any]: """Retrieve image/mask and metadata indexed by query. @@ -477,22 +494,7 @@ def _merge_files( vrt_fhs = [self._load_warp_file(fp) for fp in filepaths] bounds = (query.minx, query.miny, query.maxx, query.maxy) - if len(vrt_fhs) == 1: - src = vrt_fhs[0] - out_width = round((query.maxx - query.minx) / self.res) - out_height = round((query.maxy - query.miny) / self.res) - count = len(band_indexes) if band_indexes else src.count - out_shape = (count, out_height, out_width) - dest = src.read( - indexes=band_indexes, - out_shape=out_shape, - window=from_bounds(*bounds, src.transform), - boundless=True, - ) - else: - dest, _ = rasterio.merge.merge( - vrt_fhs, bounds, self.res, indexes=band_indexes - ) + dest, _ = rasterio.merge.merge(vrt_fhs, bounds, self.res, indexes=band_indexes) # fix numpy dtypes which are not supported by pytorch tensors if dest.dtype == np.uint16: @@ -574,7 +576,6 @@ def __init__( super().__init__(transforms) self.root = root - self.res = res self.label_name = label_name # Populate the dataset index @@ -605,6 +606,7 @@ def __init__( raise FileNotFoundError(msg) self._crs = crs + self._res = res def __getitem__(self, query: BoundingBox) -> dict[str, Any]: """Retrieve image/mask and metadata indexed by query. @@ -844,23 +846,9 @@ def __init__( if not isinstance(ds, GeoDataset): raise ValueError("IntersectionDataset only supports GeoDatasets") - self._crs = dataset1.crs + self.crs = dataset1.crs self.res = dataset1.res - # Force dataset2 to have the same CRS/res as dataset1 - if dataset1.crs != dataset2.crs: - print( - f"Converting {dataset2.__class__.__name__} CRS from " - f"{dataset2.crs} to {dataset1.crs}" - ) - dataset2.crs = dataset1.crs - if dataset1.res != dataset2.res: - print( - f"Converting {dataset2.__class__.__name__} resolution from " - f"{dataset2.res} to {dataset1.res}" - ) - dataset2.res = dataset1.res - # Merge dataset indices into a single index self._merge_dataset_indices() @@ -917,6 +905,46 @@ def __str__(self) -> str: bbox: {self.bounds} size: {len(self)}""" + @property + def crs(self) -> CRS: + """:term:`coordinate reference system (CRS)` of both datasets. + + Returns: + The :term:`coordinate reference system (CRS)`. + """ + return self._crs + + @crs.setter + def crs(self, new_crs: CRS) -> None: + """Change the :term:`coordinate reference system (CRS)` of both datasets. + + Args: + new_crs: New :term:`coordinate reference system (CRS)`. + """ + self._crs = new_crs + self.datasets[0].crs = new_crs + self.datasets[1].crs = new_crs + + @property + def res(self) -> float: + """Resolution of both datasets in units of CRS. + + Returns: + Resolution of both datasets. + """ + return self._res + + @res.setter + def res(self, new_res: float) -> None: + """Change the resolution of both datasets. + + Args: + new_res: New resolution. + """ + self._res = new_res + self.datasets[0].res = new_res + self.datasets[1].res = new_res + class UnionDataset(GeoDataset): """Dataset representing the union of two GeoDatasets. @@ -970,23 +998,9 @@ def __init__( if not isinstance(ds, GeoDataset): raise ValueError("UnionDataset only supports GeoDatasets") - self._crs = dataset1.crs + self.crs = dataset1.crs self.res = dataset1.res - # Force dataset2 to have the same CRS/res as dataset1 - if dataset1.crs != dataset2.crs: - print( - f"Converting {dataset2.__class__.__name__} CRS from " - f"{dataset2.crs} to {dataset1.crs}" - ) - dataset2.crs = dataset1.crs - if dataset1.res != dataset2.res: - print( - f"Converting {dataset2.__class__.__name__} resolution from " - f"{dataset2.res} to {dataset1.res}" - ) - dataset2.res = dataset1.res - # Merge dataset indices into a single index self._merge_dataset_indices() @@ -1040,3 +1054,43 @@ def __str__(self) -> str: type: UnionDataset bbox: {self.bounds} size: {len(self)}""" + + @property + def crs(self) -> CRS: + """:term:`coordinate reference system (CRS)` of both datasets. + + Returns: + The :term:`coordinate reference system (CRS)`. + """ + return self._crs + + @crs.setter + def crs(self, new_crs: CRS) -> None: + """Change the :term:`coordinate reference system (CRS)` of both datasets. + + Args: + new_crs: New :term:`coordinate reference system (CRS)`. + """ + self._crs = new_crs + self.datasets[0].crs = new_crs + self.datasets[1].crs = new_crs + + @property + def res(self) -> float: + """Resolution of both datasets in units of CRS. + + Returns: + The resolution of both datasets. + """ + return self._res + + @res.setter + def res(self, new_res: float) -> None: + """Change the resolution of both datasets. + + Args: + new_res: New resolution. + """ + self._res = new_res + self.datasets[0].res = new_res + self.datasets[1].res = new_res