diff --git a/tests/data/cdl/2020_30m_cdls.zip b/tests/data/cdl/2020_30m_cdls.zip
index 79759c11eee..3254477fcbd 100644
Binary files a/tests/data/cdl/2020_30m_cdls.zip and b/tests/data/cdl/2020_30m_cdls.zip differ
diff --git a/tests/data/cdl/2021_30m_cdls.zip b/tests/data/cdl/2021_30m_cdls.zip
index aebae0d8c29..0f7fb857164 100644
Binary files a/tests/data/cdl/2021_30m_cdls.zip and b/tests/data/cdl/2021_30m_cdls.zip differ
diff --git a/tests/data/cdl/data.py b/tests/data/cdl/data.py
new file mode 100644
index 00000000000..b76614fae10
--- /dev/null
+++ b/tests/data/cdl/data.py
@@ -0,0 +1,78 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import hashlib
+import os
+import random
+import shutil
+
+import numpy as np
+import rasterio
+
+SIZE = 32
+
+np.random.seed(0)
+random.seed(0)
+
+
+def create_file(path: str, dtype: str, num_channels: int) -> None:
+    """Write a SIZE x SIZE GeoTIFF of random values in [0, 8) to ``path``.
+
+    Args:
+        path: destination of the raster file
+        dtype: data type recorded in the raster profile
+        num_channels: number of bands to write (each gets the same array)
+    """
+    profile = {}
+    profile["driver"] = "GTiff"
+    profile["dtype"] = dtype
+    profile["count"] = num_channels
+    profile["crs"] = "epsg:4326"
+    profile["transform"] = rasterio.transform.from_bounds(0, 0, 1, 1, 1, 1)
+    profile["height"] = SIZE
+    profile["width"] = SIZE
+    profile["compress"] = "lzw"
+    profile["predictor"] = 2
+
+    Z = np.random.randint(size=(SIZE, SIZE), low=0, high=8)
+
+    # Use a context manager so the dataset is closed (and flushed to disk)
+    # before the caller zips the directory.
+    with rasterio.open(path, "w", **profile) as src:
+        for i in range(1, profile["count"] + 1):
+            src.write(Z, i)
+
+
+directories = ["2020_30m_cdls", "2021_30m_cdls"]
+raster_extensions = [[".tif", ".tif.ovr"], [".tif", ".tif.ovr"]]
+
+
+if __name__ == "__main__":
+
+    for directory, extensions in zip(directories, raster_extensions):
+        filename = directory + ".zip"
+
+        # Remove old data
+        if os.path.isdir(directory):
+            shutil.rmtree(directory)
+
+        os.makedirs(os.path.join(os.getcwd(), directory))
+
+        for e in extensions:
+            create_file(
+                os.path.join(directory, filename.replace(".zip", e)),
+                dtype="int8",
+                num_channels=1,
+            )
+
+        # Compress data
+        shutil.make_archive(filename.replace(".zip", ""), "zip", ".", directory)
+
+        # Compute checksums
+        with open(filename, "rb") as f:
+            md5 = hashlib.md5(f.read()).hexdigest()
+            print(f"{filename}: {md5}")
+
+        shutil.rmtree(directory)
diff --git a/tests/datasets/test_cdl.py b/tests/datasets/test_cdl.py
index 0897bfd1520..1ef997018d1 100644
--- a/tests/datasets/test_cdl.py
+++ b/tests/datasets/test_cdl.py
@@ -31,9 +31,21 @@ def dataset(
         monkeypatch.setattr(  # type: ignore[attr-defined]
             torchgeo.datasets.cdl, "download_url", download_url
         )
+        cmap = {
+            0: (0, 0, 0, 0),
+            1: (255, 211, 0, 255),
+            2: (255, 38, 38, 255),
+            3: (0, 168, 228, 255),
+            4: (255, 158, 11, 255),
+            5: (38, 112, 0, 255),
+            6: (255, 255, 0, 255),
+            7: (0, 0, 0, 255),
+            8: (0, 0, 0, 255),
+        }
+        monkeypatch.setattr(CDL, "cmap", cmap)  # type: ignore[attr-defined]
         md5s = [
-            (2021, "0693f0bb10deb79c69bcafe4aa1635b7"),
-            (2020, "7695292902a8672d16ac034d4d560d84"),
+            (2021, "4618f054004110ea11b19541b4b9f734"),
+            (2020, "593a86e62e3dd44438d536dc2442c082"),
         ]
         monkeypatch.setattr(CDL, "md5s", md5s)  # type: ignore[attr-defined]
         url = os.path.join("tests", "data", "cdl", "{}_30m_cdls.zip")
diff --git a/torchgeo/datasets/cdl.py b/torchgeo/datasets/cdl.py
index 16ca304b914..6e016a7e712 100644
--- a/torchgeo/datasets/cdl.py
+++ b/torchgeo/datasets/cdl.py
@@ -5,7 +5,7 @@
 
 import glob
 import os
-from typing import Any, Callable, Dict, Optional
+from typing import Any, Callable, Dict, Optional, Tuple
 
 from rasterio.crs import CRS
 
@@ -57,6 +57,8 @@ class CDL(RasterDataset):
         (2008, "0610f2f17ab60a9fbb3baeb7543993a4"),
     ]
 
+    cmap: Dict[int, Tuple[int, int, int, int]] = {}
+
     def __init__(
         self,
         root: str = "data",