diff --git a/torchgeo/datamodules/cyclone.py b/torchgeo/datamodules/cyclone.py index b4fe1993310..5ded6c05265 100644 --- a/torchgeo/datamodules/cyclone.py +++ b/torchgeo/datamodules/cyclone.py @@ -5,7 +5,6 @@ from typing import Any, Dict, Optional -import matplotlib.pyplot as plt import pytorch_lightning as pl import torch from sklearn.model_selection import GroupShuffleSplit @@ -149,10 +148,3 @@ def test_dataloader(self) -> DataLoader[Any]: num_workers=self.num_workers, shuffle=False, ) - - def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: - """Run :meth:`torchgeo.datasets.TropicalCycloneWindEstimation.plot`. - - .. versionadded:: 0.4 - """ - return self.all_train_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/deepglobelandcover.py b/torchgeo/datamodules/deepglobelandcover.py index 955467048e3..1a5de89aee4 100644 --- a/torchgeo/datamodules/deepglobelandcover.py +++ b/torchgeo/datamodules/deepglobelandcover.py @@ -5,7 +5,6 @@ from typing import Any, Dict, Optional -import matplotlib.pyplot as plt import pytorch_lightning as pl from torch.utils.data import DataLoader, Dataset from torchvision.transforms import Compose @@ -123,10 +122,3 @@ def test_dataloader(self) -> DataLoader[Dict[str, Any]]: num_workers=self.num_workers, shuffle=False, ) - - def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: - """Run :meth:`torchgeo.datasets.DeepGlobeLandCover.plot`. - - .. versionadded:: 0.4 - """ - return self.test_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/oscd.py b/torchgeo/datamodules/oscd.py index fbe4ff09d23..e604eed767c 100644 --- a/torchgeo/datamodules/oscd.py +++ b/torchgeo/datamodules/oscd.py @@ -6,7 +6,6 @@ from typing import Any, Dict, List, Optional, Tuple import kornia.augmentation as K -import matplotlib.pyplot as plt import pytorch_lightning as pl import torch from einops import repeat @@ -197,10 +196,3 @@ def test_dataloader(self) -> DataLoader[Any]: return DataLoader( self.test_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False ) - - def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: - """Run :meth:`torchgeo.datasets.OSCD.plot`. - - .. versionadded:: 0.4 - """ - return self.test_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/sen12ms.py b/torchgeo/datamodules/sen12ms.py index 3d3360f0008..e72d295acec 100644 --- a/torchgeo/datamodules/sen12ms.py +++ b/torchgeo/datamodules/sen12ms.py @@ -5,7 +5,6 @@ from typing import Any, Dict, Optional -import matplotlib.pyplot as plt import pytorch_lightning as pl import torch from sklearn.model_selection import GroupShuffleSplit @@ -192,10 +191,3 @@ def test_dataloader(self) -> DataLoader[Any]: num_workers=self.num_workers, shuffle=False, ) - - def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: - """Run :meth:`torchgeo.datasets.SEN12MS.plot`. - - .. versionadded:: 0.4 - """ - return self.all_train_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datasets/cyclone.py b/torchgeo/datasets/cyclone.py index c27c383284f..e12a3359d6d 100644 --- a/torchgeo/datasets/cyclone.py +++ b/torchgeo/datasets/cyclone.py @@ -232,7 +232,6 @@ def plot( .. versionadded:: 0.2 """ image, label = sample["image"], sample["label"] - image = image.permute((1, 2, 0)).numpy() showing_predictions = "prediction" in sample if showing_predictions: diff --git a/torchgeo/datasets/utils.py b/torchgeo/datasets/utils.py index 2335dd16060..1e7f43bbf87 100644 --- a/torchgeo/datasets/utils.py +++ b/torchgeo/datasets/utils.py @@ -645,7 +645,7 @@ def draw_semantic_segmentation_masks( classes = torch.from_numpy(np.arange(len(colors) if colors else 0, dtype=np.uint8)) class_masks = mask == classes[:, None, None] img = draw_segmentation_masks( - image=image.uint8(), masks=class_masks, alpha=alpha, colors=colors + image=image.byte(), masks=class_masks, alpha=alpha, colors=colors ) img = img.permute((1, 2, 0)).numpy().astype(np.uint8) return cast("np.typing.NDArray[np.uint8]", img)