Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris committed Jul 14, 2022
1 parent a7d8a5a commit 6fb67cd
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion flash/core/data/utilities/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def _load_image_from_image(file, drop_alpha: bool = True):


def _load_image_from_numpy(file):
return Image.fromarray(np.load(file).astype("uint8"), "RGB")
return Image.fromarray(np.load(file).astype("uint8")).convert("RGB")


def _load_spectrogram_from_image(file):
Expand Down
8 changes: 4 additions & 4 deletions flash/image/segmentation/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,9 @@ def from_files(
>>> from PIL import Image
>>> rand_image = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8"))
>>> rand_mask= Image.fromarray(np.random.randint(0, 10, (64, 64), dtype="uint8"))
>>> rand_mask= np.random.randint(0, 10, (64, 64), dtype="uint8")
>>> _ = [rand_image.save(f"image_{i}.png") for i in range(1, 4)]
>>> _ = [rand_mask.save(f"mask_{i}.png") for i in range(1, 4)]
>>> _ = [np.save(f"mask_{i}.npy", rand_mask) for i in range(1, 4)]
>>> _ = [rand_image.save(f"predict_image_{i}.png") for i in range(1, 4)]
.. doctest::
Expand All @@ -126,7 +126,7 @@ def from_files(
>>> from flash.image import SemanticSegmentation, SemanticSegmentationData
>>> datamodule = SemanticSegmentationData.from_files(
... train_files=["image_1.png", "image_2.png", "image_3.png"],
... train_targets=["mask_1.png", "mask_2.png", "mask_3.png"],
... train_targets=["mask_1.npy", "mask_2.npy", "mask_3.npy"],
... predict_files=["predict_image_1.png", "predict_image_2.png", "predict_image_3.png"],
... transform_kwargs=dict(image_size=(128, 128)),
... num_classes=10,
Expand All @@ -145,7 +145,7 @@ def from_files(
>>> import os
>>> _ = [os.remove(f"image_{i}.png") for i in range(1, 4)]
>>> _ = [os.remove(f"mask_{i}.png") for i in range(1, 4)]
>>> _ = [os.remove(f"mask_{i}.npy") for i in range(1, 4)]
>>> _ = [os.remove(f"predict_image_{i}.png") for i in range(1, 4)]
"""

Expand Down

0 comments on commit 6fb67cd

Please sign in to comment.