diff --git a/torchgeo/datasets/sustainbench_crop_yield.py b/torchgeo/datasets/sustainbench_crop_yield.py
index 564524f9479..eed76054ea4 100644
--- a/torchgeo/datasets/sustainbench_crop_yield.py
+++ b/torchgeo/datasets/sustainbench_crop_yield.py
@@ -47,9 +47,9 @@ class SustainBenchCropYield(NonGeoDataset):
 
     valid_countries = ["usa", "brazil", "argentina"]
 
-    md5 = "c2794e59512c897d9bea77b112848122"
+    md5 = "362bad07b51a1264172b8376b39d1fc9"
 
-    url = "https://drive.google.com/file/d/1odwkI1hiE5rMZ4VfM0hOXzlFR4NbhrfU/view?usp=share_link"  # noqa: E501
+    url = "https://drive.google.com/file/d/1lhbmICpmNuOBlaErywgiD6i9nHuhuv0A/view?usp=drive_link"  # noqa: E501
 
     dir = "soybeans"
 
@@ -96,7 +96,38 @@ def __init__(
         self.checksum = checksum
 
         self._verify()
-        self.collection = self.retrieve_collection()
+
+        self.images = []
+        self.features = []
+
+        for country in self.countries:
+            image_file_path = os.path.join(
+                self.root, self.dir, country, f"{self.split}_hists.npz"
+            )
+            target_file_path = image_file_path.replace("_hists", "_yields")
+            years_file_path = image_file_path.replace("_hists", "_years")
+            ndvi_file_path = image_file_path.replace("_hists", "_ndvi")
+
+            npz_file = np.load(image_file_path)["data"]
+            target_npz_file = np.load(target_file_path)["data"]
+            year_npz_file = np.load(years_file_path)["data"]
+            ndvi_npz_file = np.load(ndvi_file_path)["data"]
+            num_data_points = npz_file.shape[0]
+            for idx in range(num_data_points):
+                sample = npz_file[idx]
+                sample = torch.from_numpy(sample).permute(2, 0, 1).to(torch.float32)
+                self.images.append(sample)
+
+                target = target_npz_file[idx]
+                year = year_npz_file[idx]
+                ndvi = ndvi_npz_file[idx]
+
+                features = {
+                    "label": torch.tensor(target).to(torch.float32),
+                    "year": torch.tensor(int(year)),
+                    "ndvi": torch.from_numpy(ndvi).to(dtype=torch.float32),
+                }
+                self.features.append(features)
 
     def __len__(self) -> int:
         """Return the number of data points in the dataset.
@@ -104,7 +135,7 @@ def __len__(self) -> int:
         Returns:
             length of the dataset
         """
-        return len(self.collection)
+        return len(self.images)
 
     def __getitem__(self, index: int) -> dict[str, Tensor]:
         """Return an index within the dataset.
@@ -115,76 +146,14 @@ def __getitem__(self, index: int) -> dict[str, Tensor]:
         Returns:
             data and label at that index
         """
-        input_file_path, sample_idx = self.collection[index]
-
-        sample: dict[str, Tensor] = {
-            "image": self._load_image(input_file_path, sample_idx)
-        }
-        sample.update(self._load_features(input_file_path, sample_idx))
+        sample: dict[str, Tensor] = {"image": self.images[index]}
+        sample.update(self.features[index])
 
         if self.transforms is not None:
             sample = self.transforms(sample)
 
         return sample
 
-    def _load_image(self, path: str, sample_idx: int) -> Tensor:
-        """Load input image.
-
-        Args:
-            path: path to input npz collection
-            sample_idx: what sample to index from the npz collection
-
-        Returns:
-            input image as tensor
-        """
-        arr = np.load(path)["data"][sample_idx]
-        # return [channel, height, width]
-        return torch.from_numpy(arr).permute(2, 0, 1).to(torch.float32)
-
-    def _load_features(self, path: str, sample_idx: int) -> dict[str, Tensor]:
-        """Load features value.
-
-        Args:
-            path: path to image npz collection
-            sample_idx: what sample to index from the npz collection
-
-        Returns:
-            target regression value
-        """
-        target_file_path = path.replace("_hists", "_yields")
-        target = np.load(target_file_path)["data"][sample_idx]
-
-        years_file_path = path.replace("_hists", "_years")
-        year = int(np.load(years_file_path)["data"][sample_idx])
-
-        ndvi_file_path = path.replace("_hists", "_ndvi")
-        ndvi = np.load(ndvi_file_path)["data"][sample_idx]
-
-        features = {
-            "label": torch.tensor(target).to(torch.float32),
-            "year": torch.tensor(year),
-            "ndvi": torch.from_numpy(ndvi).to(dtype=torch.float32),
-        }
-        return features
-
-    def retrieve_collection(self) -> list[tuple[str, int]]:
-        """Retrieve the collection.
-
-        Returns:
-            path and index to dataset samples
-        """
-        collection = []
-        for country in self.countries:
-            file_path = os.path.join(
-                self.root, self.dir, country, f"{self.split}_hists.npz"
-            )
-            npz_file = np.load(file_path)
-            num_data_points = npz_file["data"].shape[0]
-            for idx in range(num_data_points):
-                collection.append((file_path, idx))
-
-        return collection
-
     def _verify(self) -> None:
         """Verify the integrity of the dataset."""
         # Check if the extracted files already exist
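For reference, a minimal sketch of how the dataset behaves after this change. This is not part of the diff: the `root` path and `countries` value are illustrative, and it assumes the SustainBench soybean `.npz` files are already extracted under `root` so no download is triggered.

```python
from torchgeo.datasets import SustainBenchCropYield

# Every sample is materialized eagerly in __init__, so __len__ and
# __getitem__ reduce to plain list lookups.
ds = SustainBenchCropYield(root="data", split="train", countries=["usa"])
print(len(ds))

sample = ds[0]
print(sample["image"].shape)  # [channel, height, width], float32
print(sample["label"])        # crop-yield regression target, float32
print(sample["year"])         # harvest year as an integer tensor
print(sample["ndvi"].shape)   # NDVI values for the sample, float32
```

The trade-off is deliberate: reading all four `.npz` archives per country once up front costs startup time and memory, but `__getitem__` no longer re-opens and re-indexes those archives on every access, which is what made the lazy `retrieve_collection`/`_load_image`/`_load_features` path slow.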