Skip to content

Commit

Permalink
Fix bug in unbind_samples, add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjstewart committed Dec 24, 2021
1 parent 65cc6df commit 9657a67
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 9 deletions.
22 changes: 20 additions & 2 deletions tests/datasets/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
merge_samples,
percentile_normalization,
stack_samples,
unbind_samples,
working_dir,
)

Expand Down Expand Up @@ -457,7 +458,7 @@ def samples(self) -> List[Dict[str, Any]]:
},
]

def test_stack_samples(self, samples: List[Dict[str, Any]]) -> None:
def test_stack_unbind_samples(self, samples: List[Dict[str, Any]]) -> None:
sample = stack_samples(samples)
assert sample["image"].size() == torch.Size( # type: ignore[attr-defined]
[2, 3]
Expand All @@ -468,6 +469,13 @@ def test_stack_samples(self, samples: List[Dict[str, Any]]) -> None:
)
assert sample["crs"] == [CRS.from_epsg(2000), CRS.from_epsg(2001)]

new_samples = unbind_samples(sample)
for i in range(2):
assert torch.allclose( # type: ignore[attr-defined]
samples[i]["image"], new_samples[i]["image"]
)
assert samples[i]["crs"] == new_samples[i]["crs"]

def test_concat_samples(self, samples: List[Dict[str, Any]]) -> None:
sample = concat_samples(samples)
assert sample["image"].size() == torch.Size([6]) # type: ignore[attr-defined]
Expand Down Expand Up @@ -500,7 +508,7 @@ def samples(self) -> List[Dict[str, Any]]:
},
]

def test_stack_samples(self, samples: List[Dict[str, Any]]) -> None:
def test_stack_unbind_samples(self, samples: List[Dict[str, Any]]) -> None:
sample = stack_samples(samples)
assert sample["image"].size() == torch.Size( # type: ignore[attr-defined]
[1, 3]
Expand All @@ -515,6 +523,16 @@ def test_stack_samples(self, samples: List[Dict[str, Any]]) -> None:
assert sample["crs1"] == [CRS.from_epsg(2000)]
assert sample["crs2"] == [CRS.from_epsg(2001)]

new_samples = unbind_samples(sample)
assert torch.allclose( # type: ignore[attr-defined]
samples[0]["image"], new_samples[0]["image"]
)
assert samples[0]["crs1"] == new_samples[0]["crs1"]
assert torch.allclose( # type: ignore[attr-defined]
samples[1]["mask"], new_samples[0]["mask"]
)
assert samples[1]["crs2"] == new_samples[0]["crs2"]

def test_concat_samples(self, samples: List[Dict[str, Any]]) -> None:
sample = concat_samples(samples)
assert sample["image"].size() == torch.Size([3]) # type: ignore[attr-defined]
Expand Down
29 changes: 22 additions & 7 deletions torchgeo/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,26 @@ def _list_dict_to_dict_list(samples: Iterable[Dict[Any, Any]]) -> Dict[Any, List
return collated


def _dict_list_to_list_dict(sample: Dict[Any, Sequence[Any]]) -> List[Dict[Any, Any]]:
"""Convert a dictionary of lists to a list of dictionaries.
Args:
sample: a dictionary of lists
Returns:
a list of dictionaries
.. versionadded:: 0.2
"""
uncollated: List[Dict[Any, Any]] = [
{} for _ in range(max(map(len, sample.values())))
]
for key, values in sample.items():
for i, value in enumerate(values):
uncollated[i][key] = value
return uncollated


def stack_samples(samples: Iterable[Dict[Any, Any]]) -> Dict[Any, Any]:
"""Stack a list of samples along a new axis.
Expand Down Expand Up @@ -532,15 +552,10 @@ def unbind_samples(sample: Dict[Any, Sequence[Any]]) -> List[Dict[Any, Any]]:
.. versionadded:: 0.2
"""
uncollated: List[Dict[Any, Any]] = [{}] * max(map(len, sample.values()))
for key, values in sample.items():
if isinstance(values, Tensor):
for i, value in enumerate(torch.unbind(values)):
uncollated[i][key] = value
else:
for i, value in enumerate(values):
uncollated[i][key] = value
return uncollated
sample[key] = torch.unbind(values)
return _dict_list_to_list_dict(sample)


def rasterio_loader(path: str) -> np.ndarray: # type: ignore[type-arg]
Expand Down

0 comments on commit 9657a67

Please sign in to comment.