Skip to content

Commit

Permalink
More fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjstewart committed Oct 4, 2022
1 parent 85fd452 commit 54c887d
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion tests/trainers/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from typing import Any, Dict, Type, cast

import pytest
from _pytest.monkeypatch import MonkeyPatch
from omegaconf import OmegaConf
from pytorch_lightning import LightningDataModule, Trainer
from _pytest.monkeypatch import MonkeyPatch

from torchgeo.datamodules import COWCCountingDataModule, CycloneDataModule
from torchgeo.trainers import RegressionTask
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,7 @@ def draw_semantic_segmentation_masks(
classes = torch.from_numpy(np.arange(len(colors) if colors else 0, dtype=np.uint8))
class_masks = mask == classes[:, None, None]
img = draw_segmentation_masks(
image=image, masks=class_masks, alpha=alpha, colors=colors
image=image.uint8(), masks=class_masks, alpha=alpha, colors=colors
)
img = img.permute((1, 2, 0)).numpy().astype(np.uint8)
return cast("np.typing.NDArray[np.uint8]", img)
Expand Down

0 comments on commit 54c887d

Please sign in to comment.