diff --git a/tests/conf/ucmerced.yaml b/tests/conf/ucmerced.yaml index 1a2ddf8add4..3c544564ae8 100644 --- a/tests/conf/ucmerced.yaml +++ b/tests/conf/ucmerced.yaml @@ -11,5 +11,5 @@ experiment: datamodule: root: "tests/data/ucmerced" download: true - batch_size: 1 + batch_size: 2 num_workers: 0 diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index f5ed1d3ec38..4cd54d5d9b3 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -736,6 +736,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: Args: index: index to return + Returns: data and label at that index """ @@ -756,13 +757,13 @@ def __len__(self) -> int: return len(self.imgs) def _load_image(self, index: int) -> tuple[Tensor, Tensor]: - """Load a single image and it's class label. + """Load a single image and its class label. Args: index: index to return + Returns: - the image - the image class label + the image and class label """ img, label = ImageFolder.__getitem__(self, index) array: "np.typing.NDArray[np.int_]" = np.array(img) diff --git a/torchgeo/datasets/ucmerced.py b/torchgeo/datasets/ucmerced.py index 145cda876c5..ce3966646c3 100644 --- a/torchgeo/datasets/ucmerced.py +++ b/torchgeo/datasets/ucmerced.py @@ -7,6 +7,7 @@ import matplotlib.pyplot as plt import numpy as np +import torchvision.transforms.functional as F from torch import Tensor from .geo import NonGeoClassificationDataset @@ -143,6 +144,19 @@ def __init__( is_valid_file=is_in_split, ) + def _load_image(self, index: int) -> tuple[Tensor, Tensor]: + """Load a single image and its class label. + + Args: + index: index to return + + Returns: + the image and class label + """ + img, label = super()._load_image(index) + img = F.resize(img, size=(256, 256), antialias=True) + return img, label + def _check_integrity(self) -> bool: """Check integrity of dataset.