Updates to SustainBenchCropYield dataset #1756

Merged
merged 3 commits on Dec 7, 2023
105 changes: 37 additions & 68 deletions torchgeo/datasets/sustainbench_crop_yield.py
@@ -47,9 +47,9 @@ class SustainBenchCropYield(NonGeoDataset):

valid_countries = ["usa", "brazil", "argentina"]

md5 = "c2794e59512c897d9bea77b112848122"
md5 = "362bad07b51a1264172b8376b39d1fc9"

url = "https://drive.google.com/file/d/1odwkI1hiE5rMZ4VfM0hOXzlFR4NbhrfU/view?usp=share_link" # noqa: E501
url = "https://drive.google.com/file/d/1lhbmICpmNuOBlaErywgiD6i9nHuhuv0A/view?usp=drive_link" # noqa: E501

dir = "soybeans"

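Since this hunk only swaps the Google Drive file ID and its MD5, reviewers pulling the new asset may want to recompute the checksum locally. A minimal sketch follows; the archive filename "soybeans.zip" is an assumption, not something the diff specifies.

# Sketch (not part of this PR's diff):
import hashlib

def file_md5(path: str, chunk_size: int = 1 << 20) -> str:
    """Compute an MD5 digest without loading the whole file into memory."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

# Compare against the updated class attribute.
print(file_md5("soybeans.zip") == "362bad07b51a1264172b8376b39d1fc9")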
@@ -96,15 +96,46 @@ def __init__(
self.checksum = checksum

self._verify()
-        self.collection = self.retrieve_collection()

+        self.images = []
+        self.features = []
+
+        for country in self.countries:
+            image_file_path = os.path.join(
+                self.root, self.dir, country, f"{self.split}_hists.npz"
+            )
+            target_file_path = image_file_path.replace("_hists", "_yields")
+            years_file_path = image_file_path.replace("_hists", "_years")
+            ndvi_file_path = image_file_path.replace("_hists", "_ndvi")
+
+            npz_file = np.load(image_file_path)["data"]
+            target_npz_file = np.load(target_file_path)["data"]
+            year_npz_file = np.load(years_file_path)["data"]
+            ndvi_npz_file = np.load(ndvi_file_path)["data"]
+            num_data_points = npz_file.shape[0]
+            for idx in range(num_data_points):
+                sample = npz_file[idx]
+                sample = torch.from_numpy(sample).permute(2, 0, 1).to(torch.float32)
+                self.images.append(sample)
+
+                target = target_npz_file[idx]
+                year = year_npz_file[idx]
+                ndvi = ndvi_npz_file[idx]
+
+                features = {
+                    "label": torch.tensor(target).to(torch.float32),
+                    "year": torch.tensor(int(year)),
+                    "ndvi": torch.from_numpy(ndvi).to(dtype=torch.float32),
+                }
+                self.features.append(features)

def __len__(self) -> int:
"""Return the number of data points in the dataset.

Returns:
length of the dataset
"""
-        return len(self.collection)
+        return len(self.images)

def __getitem__(self, index: int) -> dict[str, Tensor]:
"""Return an index within the dataset.
@@ -115,76 +115,14 @@ def __getitem__(self, index: int) -> dict[str, Tensor]:
Returns:
data and label at that index
"""
-        input_file_path, sample_idx = self.collection[index]
-
-        sample: dict[str, Tensor] = {
-            "image": self._load_image(input_file_path, sample_idx)
-        }
-        sample.update(self._load_features(input_file_path, sample_idx))
+        sample: dict[str, Tensor] = {"image": self.images[index]}
+        sample.update(self.features[index])

if self.transforms is not None:
sample = self.transforms(sample)

return sample

-    def _load_image(self, path: str, sample_idx: int) -> Tensor:
-        """Load input image.
-
-        Args:
-            path: path to input npz collection
-            sample_idx: what sample to index from the npz collection
-
-        Returns:
-            input image as tensor
-        """
-        arr = np.load(path)["data"][sample_idx]
-        # return [channel, height, width]
-        return torch.from_numpy(arr).permute(2, 0, 1).to(torch.float32)
-
-    def _load_features(self, path: str, sample_idx: int) -> dict[str, Tensor]:
-        """Load features value.
-
-        Args:
-            path: path to image npz collection
-            sample_idx: what sample to index from the npz collection
-
-        Returns:
-            target regression value
-        """
-        target_file_path = path.replace("_hists", "_yields")
-        target = np.load(target_file_path)["data"][sample_idx]
-
-        years_file_path = path.replace("_hists", "_years")
-        year = int(np.load(years_file_path)["data"][sample_idx])
-
-        ndvi_file_path = path.replace("_hists", "_ndvi")
-        ndvi = np.load(ndvi_file_path)["data"][sample_idx]
-
-        features = {
-            "label": torch.tensor(target).to(torch.float32),
-            "year": torch.tensor(year),
-            "ndvi": torch.from_numpy(ndvi).to(dtype=torch.float32),
-        }
-        return features
-
-    def retrieve_collection(self) -> list[tuple[str, int]]:
-        """Retrieve the collection.
-
-        Returns:
-            path and index to dataset samples
-        """
-        collection = []
-        for country in self.countries:
-            file_path = os.path.join(
-                self.root, self.dir, country, f"{self.split}_hists.npz"
-            )
-            npz_file = np.load(file_path)
-            num_data_points = npz_file["data"].shape[0]
-            for idx in range(num_data_points):
-                collection.append((file_path, idx))
-
-        return collection

def _verify(self) -> None:
"""Verify the integrity of the dataset."""
# Check if the extracted files already exist
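Finally, a short usage sketch of the class after this PR: all samples are materialized in memory at construction time, so indexing is a plain list lookup. Parameter defaults below are assumptions based on the attributes visible in the diff.

# Sketch (not part of this PR's diff); assumes the soybeans data already
# exists under root, since download behavior is not shown in this hunk.
from torch.utils.data import DataLoader
from torchgeo.datasets import SustainBenchCropYield

ds = SustainBenchCropYield(root="data", split="train", countries=["usa"])
print(len(ds))  # every sample was loaded into self.images during __init__

sample = ds[0]  # dict with "image", "label", "year", and "ndvi" tensors
print(sample["image"].shape, sample["label"].item(), sample["year"].item())

loader = DataLoader(ds, batch_size=32, shuffle=True)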