diff --git a/tests/datasets/test_utils.py b/tests/datasets/test_utils.py
index 631897bb7f2..cd6d2b36a24 100644
--- a/tests/datasets/test_utils.py
+++ b/tests/datasets/test_utils.py
@@ -31,6 +31,7 @@
     merge_samples,
     percentile_normalization,
     stack_samples,
+    unbind_samples,
     working_dir,
 )
 
@@ -457,7 +458,7 @@ def samples(self) -> List[Dict[str, Any]]:
             },
         ]
 
-    def test_stack_samples(self, samples: List[Dict[str, Any]]) -> None:
+    def test_stack_unbind_samples(self, samples: List[Dict[str, Any]]) -> None:
         sample = stack_samples(samples)
         assert sample["image"].size() == torch.Size(  # type: ignore[attr-defined]
             [2, 3]
@@ -468,6 +469,13 @@ def test_stack_samples(self, samples: List[Dict[str, Any]]) -> None:
         )
         assert sample["crs"] == [CRS.from_epsg(2000), CRS.from_epsg(2001)]
 
+        new_samples = unbind_samples(sample)
+        for i in range(2):
+            assert torch.allclose(  # type: ignore[attr-defined]
+                samples[i]["image"], new_samples[i]["image"]
+            )
+            assert samples[i]["crs"] == new_samples[i]["crs"]
+
     def test_concat_samples(self, samples: List[Dict[str, Any]]) -> None:
         sample = concat_samples(samples)
         assert sample["image"].size() == torch.Size([6])  # type: ignore[attr-defined]
@@ -500,7 +508,7 @@ def samples(self) -> List[Dict[str, Any]]:
             },
         ]
 
-    def test_stack_samples(self, samples: List[Dict[str, Any]]) -> None:
+    def test_stack_unbind_samples(self, samples: List[Dict[str, Any]]) -> None:
         sample = stack_samples(samples)
         assert sample["image"].size() == torch.Size(  # type: ignore[attr-defined]
             [1, 3]
@@ -515,6 +523,16 @@ def test_stack_samples(self, samples: List[Dict[str, Any]]) -> None:
         assert sample["crs1"] == [CRS.from_epsg(2000)]
         assert sample["crs2"] == [CRS.from_epsg(2001)]
 
+        new_samples = unbind_samples(sample)
+        assert torch.allclose(  # type: ignore[attr-defined]
+            samples[0]["image"], new_samples[0]["image"]
+        )
+        assert samples[0]["crs1"] == new_samples[0]["crs1"]
+        assert torch.allclose(  # type: ignore[attr-defined]
+            samples[1]["mask"], new_samples[0]["mask"]
+        )
+        assert samples[1]["crs2"] == new_samples[0]["crs2"]
+
     def test_concat_samples(self, samples: List[Dict[str, Any]]) -> None:
         sample = concat_samples(samples)
         assert sample["image"].size() == torch.Size([3])  # type: ignore[attr-defined]
diff --git a/torchgeo/datasets/utils.py b/torchgeo/datasets/utils.py
index 83f7b749144..091d6f16c37 100644
--- a/torchgeo/datasets/utils.py
+++ b/torchgeo/datasets/utils.py
@@ -448,6 +448,30 @@ def _list_dict_to_dict_list(samples: Iterable[Dict[Any, Any]]) -> Dict[Any, List
     return collated
 
 
+def _dict_list_to_list_dict(sample: Dict[Any, Sequence[Any]]) -> List[Dict[Any, Any]]:
+    """Convert a dictionary of lists to a list of dictionaries.
+
+    The result has one dictionary per index of the longest list; a key whose
+    list is shorter is simply absent from the later dictionaries.
+
+    Args:
+        sample: a dictionary of lists
+
+    Returns:
+        a list of dictionaries (empty if ``sample`` has no keys)
+
+    .. versionadded:: 0.2
+    """
+    # default=0 keeps an empty ``sample`` from raising ValueError in max()
+    uncollated: List[Dict[Any, Any]] = [
+        {} for _ in range(max(map(len, sample.values()), default=0))
+    ]
+    for key, values in sample.items():
+        for i, value in enumerate(values):
+            uncollated[i][key] = value
+    return uncollated
+
+
 def stack_samples(samples: Iterable[Dict[Any, Any]]) -> Dict[Any, Any]:
     """Stack a list of samples along a new axis.
 
@@ -532,15 +556,12 @@ def unbind_samples(sample: Dict[Any, Sequence[Any]]) -> List[Dict[Any, Any]]:
 
     .. versionadded:: 0.2
     """
-    uncollated: List[Dict[Any, Any]] = [{}] * max(map(len, sample.values()))
+    # Unbind into a shallow copy so the caller's batch keeps its stacked tensors
+    unbound: Dict[Any, Sequence[Any]] = dict(sample)
     for key, values in sample.items():
         if isinstance(values, Tensor):
-            for i, value in enumerate(torch.unbind(values)):
-                uncollated[i][key] = value
-        else:
-            for i, value in enumerate(values):
-                uncollated[i][key] = value
-    return uncollated
+            unbound[key] = torch.unbind(values)
+    return _dict_list_to_list_dict(unbound)
 
 
 def rasterio_loader(path: str) -> np.ndarray:  # type: ignore[type-arg]