From 83a978a32dbc212b0b1475a3d509abba2c20d2c9 Mon Sep 17 00:00:00 2001 From: Isaac Corley <22203655+isaaccorley@users.noreply.github.com> Date: Sun, 16 Apr 2023 18:57:16 -0500 Subject: [PATCH] UCMerced: fix image shape bug (#1238) * fix resize bug in ucmerced * remove unused test * update tests * move resize to dataset * fix resize bug in ucmerced * remove unused test * update tests * remove changes to datamodule * Update torchgeo/datasets/ucmerced.py Co-authored-by: Adam J. Stewart * fix docstring * fix resize bug in ucmerced * remove unused test * update tests * move resize to dataset * fix resize bug in ucmerced * remove unused test * update tests * remove changes to datamodule * Update torchgeo/datasets/ucmerced.py Co-authored-by: Adam J. Stewart * fix docstring * update docstring * update docstring x3 * remove Dict * remove Tuple * Fix base class docs * Grammar fix --------- Co-authored-by: Adam J. Stewart --- tests/conf/ucmerced.yaml | 2 +- torchgeo/datasets/geo.py | 7 ++++--- torchgeo/datasets/ucmerced.py | 14 ++++++++++++++ 3 files changed, 19 insertions(+), 4 deletions(-) diff --git a/tests/conf/ucmerced.yaml b/tests/conf/ucmerced.yaml index 1a2ddf8add4..3c544564ae8 100644 --- a/tests/conf/ucmerced.yaml +++ b/tests/conf/ucmerced.yaml @@ -11,5 +11,5 @@ experiment: datamodule: root: "tests/data/ucmerced" download: true - batch_size: 1 + batch_size: 2 num_workers: 0 diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index f5ed1d3ec38..4cd54d5d9b3 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -736,6 +736,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: Args: index: index to return + Returns: data and label at that index """ @@ -756,13 +757,13 @@ def __len__(self) -> int: return len(self.imgs) def _load_image(self, index: int) -> tuple[Tensor, Tensor]: - """Load a single image and it's class label. + """Load a single image and its class label. Args: index: index to return + Returns: - the image - the image class label + the image and class label """ img, label = ImageFolder.__getitem__(self, index) array: "np.typing.NDArray[np.int_]" = np.array(img) diff --git a/torchgeo/datasets/ucmerced.py b/torchgeo/datasets/ucmerced.py index 145cda876c5..ce3966646c3 100644 --- a/torchgeo/datasets/ucmerced.py +++ b/torchgeo/datasets/ucmerced.py @@ -7,6 +7,7 @@ import matplotlib.pyplot as plt import numpy as np +import torchvision.transforms.functional as F from torch import Tensor from .geo import NonGeoClassificationDataset @@ -143,6 +144,19 @@ def __init__( is_valid_file=is_in_split, ) + def _load_image(self, index: int) -> tuple[Tensor, Tensor]: + """Load a single image and its class label. + + Args: + index: index to return + + Returns: + the image and class label + """ + img, label = super()._load_image(index) + img = F.resize(img, size=(256, 256), antialias=True) + return img, label + def _check_integrity(self) -> bool: """Check integrity of dataset.