Skip to content

Commit

Permalink
returning the centroid of each cell in the sample dict (#1240)
Browse files Browse the repository at this point in the history
* returning the centroid of each cell in the sample dict

* updating usavars test to reflect centroid location as part of each sample

* returning the centroid of each cell in the sample dict

* updating usavars test to reflect centroid location as part of each sample

* separating lat and lon in each sample

* separating lat and lon in each sample
  • Loading branch information
estherrolf authored Apr 17, 2023
1 parent 83a978a commit 8d78b47
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
4 changes: 3 additions & 1 deletion tests/datasets/test_usavars.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,11 @@ def test_getitem(self, dataset: USAVars) -> None:
assert isinstance(x, dict)
assert isinstance(x["image"], torch.Tensor)
assert x["image"].ndim == 3
assert len(x.keys()) == 2 # image, labels
assert len(x.keys()) == 4 # image, labels, centroid_lat, centroid_lon
assert x["image"].shape[0] == 4 # R, G, B, Inf
assert len(dataset.labels) == len(x["labels"])
assert len(x["centroid_lat"]) == 1
assert len(x["centroid_lon"]) == 1

def test_len(self, dataset: USAVars) -> None:
if dataset.split == "train":
Expand Down
2 changes: 2 additions & 0 deletions torchgeo/datasets/usavars.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ def __getitem__(self, index: int) -> dict[str, Tensor]:
[self.label_dfs[lab].loc[id_][lab] for lab in self.labels]
),
"image": self._load_image(os.path.join(self.root, "uar", tif_file)),
"centroid_lat": Tensor([self.label_dfs[self.labels[0]].loc[id_]["lat"]]),
"centroid_lon": Tensor([self.label_dfs[self.labels[0]].loc[id_]["lon"]]),
}

if self.transforms is not None:
Expand Down

0 comments on commit 8d78b47

Please sign in to comment.