Skip to content

Commit

Permalink
Undo changes that break code
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjstewart committed Oct 14, 2022
1 parent e9b6435 commit e4a1c02
Show file tree
Hide file tree
Showing 6 changed files with 1 addition and 34 deletions.
8 changes: 0 additions & 8 deletions torchgeo/datamodules/cyclone.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from typing import Any, Dict, Optional

import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
from sklearn.model_selection import GroupShuffleSplit
Expand Down Expand Up @@ -149,10 +148,3 @@ def test_dataloader(self) -> DataLoader[Any]:
num_workers=self.num_workers,
shuffle=False,
)

def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
"""Run :meth:`torchgeo.datasets.TropicalCycloneWindEstimation.plot`.
.. versionadded:: 0.4
"""
return self.all_train_dataset.plot(*args, **kwargs)
8 changes: 0 additions & 8 deletions torchgeo/datamodules/deepglobelandcover.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from typing import Any, Dict, Optional

import matplotlib.pyplot as plt
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose
Expand Down Expand Up @@ -123,10 +122,3 @@ def test_dataloader(self) -> DataLoader[Dict[str, Any]]:
num_workers=self.num_workers,
shuffle=False,
)

def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
"""Run :meth:`torchgeo.datasets.DeepGlobeLandCover.plot`.
.. versionadded:: 0.4
"""
return self.test_dataset.plot(*args, **kwargs)
8 changes: 0 additions & 8 deletions torchgeo/datamodules/oscd.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from typing import Any, Dict, List, Optional, Tuple

import kornia.augmentation as K
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
from einops import repeat
Expand Down Expand Up @@ -197,10 +196,3 @@ def test_dataloader(self) -> DataLoader[Any]:
return DataLoader(
self.test_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False
)

def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
"""Run :meth:`torchgeo.datasets.OSCD.plot`.
.. versionadded:: 0.4
"""
return self.test_dataset.plot(*args, **kwargs)
8 changes: 0 additions & 8 deletions torchgeo/datamodules/sen12ms.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from typing import Any, Dict, Optional

import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
from sklearn.model_selection import GroupShuffleSplit
Expand Down Expand Up @@ -192,10 +191,3 @@ def test_dataloader(self) -> DataLoader[Any]:
num_workers=self.num_workers,
shuffle=False,
)

def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
"""Run :meth:`torchgeo.datasets.SEN12MS.plot`.
.. versionadded:: 0.4
"""
return self.all_train_dataset.plot(*args, **kwargs)
1 change: 0 additions & 1 deletion torchgeo/datasets/cyclone.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,6 @@ def plot(
.. versionadded:: 0.2
"""
image, label = sample["image"], sample["label"]
image = image.permute((1, 2, 0)).numpy()

showing_predictions = "prediction" in sample
if showing_predictions:
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,7 @@ def draw_semantic_segmentation_masks(
classes = torch.from_numpy(np.arange(len(colors) if colors else 0, dtype=np.uint8))
class_masks = mask == classes[:, None, None]
img = draw_segmentation_masks(
image=image.uint8(), masks=class_masks, alpha=alpha, colors=colors
image=image.byte(), masks=class_masks, alpha=alpha, colors=colors
)
img = img.permute((1, 2, 0)).numpy().astype(np.uint8)
return cast("np.typing.NDArray[np.uint8]", img)
Expand Down

0 comments on commit e4a1c02

Please sign in to comment.