Skip to content

Commit

Permalink
Add own data to CDL dataset. (#429)
Browse files Browse the repository at this point in the history
* add own data

* data.py with cmap

* remove .img test data format
  • Loading branch information
nilsleh authored and adamjstewart committed Mar 19, 2022
1 parent cd125b1 commit c6f77e4
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 3 deletions.
Binary file modified tests/data/cdl/2020_30m_cdls.zip
Binary file not shown.
Binary file modified tests/data/cdl/2021_30m_cdls.zip
Binary file not shown.
69 changes: 69 additions & 0 deletions tests/data/cdl/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import hashlib
import os
import random
import shutil

import numpy as np
import rasterio

# Height and width, in pixels, of every raster written by this script.
SIZE = 32

# Seed both random number generators so repeated runs draw the same values.
random.seed(0)
np.random.seed(0)


def create_file(path: str, dtype: str, num_channels: int) -> None:
    """Write a small random GeoTIFF raster to ``path``.

    A single ``SIZE`` x ``SIZE`` array of random integers in ``[0, 8)`` is
    drawn once and written to every band of the output file.

    Args:
        path: destination path of the raster file
        dtype: data type name stored in the raster profile
        num_channels: number of bands to write
    """
    profile = {
        "driver": "GTiff",
        "dtype": dtype,
        "count": num_channels,
        "crs": "epsg:4326",
        # NOTE(review): the transform is computed for a 1x1 grid while the
        # raster is SIZE x SIZE -- confirm this is intended before changing,
        # since it affects the file contents and therefore the checksums.
        "transform": rasterio.transform.from_bounds(0, 0, 1, 1, 1, 1),
        "height": SIZE,
        "width": SIZE,
        "compress": "lzw",
        "predictor": 2,
    }

    Z = np.random.randint(size=(SIZE, SIZE), low=0, high=8)

    # Use a context manager so the dataset is closed (and flushed to disk)
    # before the caller archives and checksums the file.
    with rasterio.open(path, "w", **profile) as src:
        for band in range(1, num_channels + 1):
            src.write(Z, band)


# Names of the directories to generate; each one is zipped to "<name>.zip".
directories = ["2020_30m_cdls", "2021_30m_cdls"]
# File extensions to create inside each directory, paired by position with
# ``directories``.
raster_extensions = [[".tif", ".tif.ovr"], [".tif", ".tif.ovr"]]


if __name__ == "__main__":

    # ``directory`` rather than ``dir`` so the builtin is not shadowed.
    for directory, extensions in zip(directories, raster_extensions):
        filename = directory + ".zip"

        # Remove old data
        if os.path.isdir(directory):
            shutil.rmtree(directory)

        os.makedirs(os.path.join(os.getcwd(), directory))

        # Each raster is named after its directory, e.g. <directory>/<directory>.tif
        for extension in extensions:
            create_file(
                os.path.join(directory, directory + extension),
                dtype="int8",
                num_channels=1,
            )

        # Compress data
        shutil.make_archive(directory, "zip", ".", directory)

        # Compute checksums
        with open(filename, "rb") as f:
            md5 = hashlib.md5(f.read()).hexdigest()
            print(f"{filename}: {md5}")

        # Only the archive is kept; the uncompressed directory is removed.
        shutil.rmtree(directory)
16 changes: 14 additions & 2 deletions tests/datasets/test_cdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,21 @@ def dataset(
monkeypatch.setattr( # type: ignore[attr-defined]
torchgeo.datasets.cdl, "download_url", download_url
)
cmap = {
0: (0, 0, 0, 0),
1: (255, 211, 0, 255),
2: (255, 38, 38, 255),
3: (0, 168, 228, 255),
4: (255, 158, 11, 255),
5: (38, 112, 0, 255),
6: (255, 255, 0, 255),
7: (0, 0, 0, 255),
8: (0, 0, 0, 255),
}
monkeypatch.setattr(CDL, "cmap", cmap) # type: ignore[attr-defined]
md5s = [
(2021, "0693f0bb10deb79c69bcafe4aa1635b7"),
(2020, "7695292902a8672d16ac034d4d560d84"),
(2021, "4618f054004110ea11b19541b4b9f734"),
(2020, "593a86e62e3dd44438d536dc2442c082"),
]
monkeypatch.setattr(CDL, "md5s", md5s) # type: ignore[attr-defined]
url = os.path.join("tests", "data", "cdl", "{}_30m_cdls.zip")
Expand Down
4 changes: 3 additions & 1 deletion torchgeo/datasets/cdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import glob
import os
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Dict, Optional, Tuple

from rasterio.crs import CRS

Expand Down Expand Up @@ -57,6 +57,8 @@ class CDL(RasterDataset):
(2008, "0610f2f17ab60a9fbb3baeb7543993a4"),
]

cmap: Dict[int, Tuple[int, int, int, int]] = {}

def __init__(
self,
root: str = "data",
Expand Down

0 comments on commit c6f77e4

Please sign in to comment.