Skip to content

Commit

Permalink
UCMerced: fix image shape bug (#1238)
Browse files Browse the repository at this point in the history
* fix resize bug in ucmerced

* remove unused test

* update tests

* move resize to dataset

* fix resize bug in ucmerced

* remove unused test

* update tests

* remove changes to datamodule

* Update torchgeo/datasets/ucmerced.py

Co-authored-by: Adam J. Stewart <[email protected]>

* fix docstring

* fix resize bug in ucmerced

* remove unused test

* update tests

* move resize to dataset

* fix resize bug in ucmerced

* remove unused test

* update tests

* remove changes to datamodule

* Update torchgeo/datasets/ucmerced.py

Co-authored-by: Adam J. Stewart <[email protected]>

* fix docstring

* update docstring

* update docstring x3

* remove Dict

* remove Tuple

* Fix base class docs

* Grammar fix

---------

Co-authored-by: Adam J. Stewart <[email protected]>
  • Loading branch information
isaaccorley and adamjstewart authored Apr 16, 2023
1 parent e0f9ece commit 83a978a
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 4 deletions.
2 changes: 1 addition & 1 deletion tests/conf/ucmerced.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,5 @@ experiment:
datamodule:
root: "tests/data/ucmerced"
download: true
batch_size: 1
batch_size: 2
num_workers: 0
7 changes: 4 additions & 3 deletions torchgeo/datasets/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,6 +736,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]:
Args:
index: index to return
Returns:
data and label at that index
"""
Expand All @@ -756,13 +757,13 @@ def __len__(self) -> int:
return len(self.imgs)

def _load_image(self, index: int) -> tuple[Tensor, Tensor]:
"""Load a single image and it's class label.
"""Load a single image and its class label.
Args:
index: index to return
Returns:
the image
the image class label
the image and class label
"""
img, label = ImageFolder.__getitem__(self, index)
array: "np.typing.NDArray[np.int_]" = np.array(img)
Expand Down
14 changes: 14 additions & 0 deletions torchgeo/datasets/ucmerced.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms.functional as F
from torch import Tensor

from .geo import NonGeoClassificationDataset
Expand Down Expand Up @@ -143,6 +144,19 @@ def __init__(
is_valid_file=is_in_split,
)

def _load_image(self, index: int) -> tuple[Tensor, Tensor]:
"""Load a single image and its class label.
Args:
index: index to return
Returns:
the image and class label
"""
img, label = super()._load_image(index)
img = F.resize(img, size=(256, 256), antialias=True)
return img, label

def _check_integrity(self) -> bool:
"""Check integrity of dataset.
Expand Down

0 comments on commit 83a978a

Please sign in to comment.