diff --git a/torchgeo/datasets/oscd.py b/torchgeo/datasets/oscd.py index 6f5b7003d20..e24fbe4612c 100644 --- a/torchgeo/datasets/oscd.py +++ b/torchgeo/datasets/oscd.py @@ -204,7 +204,7 @@ def _load_image(self, paths: Sequence[str]) -> Tensor: with Image.open(path) as img: images.append(np.array(img)) array: "np.typing.NDArray[np.int_]" = np.stack(images, axis=0).astype(np.int_) - tensor = torch.from_numpy(array) + tensor = torch.from_numpy(array).float() return tensor def _load_target(self, path: str) -> Tensor: